Source code for psynlig.scatter

# Copyright (c) 2020, Anders Lervik.
# Distributed under the MIT License. See LICENSE for more info.
"""A module for generating scatter plots of variables."""
from itertools import combinations
import warnings
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # pylint: disable=unused-import
from scipy.special import comb
from .colors import generate_class_colors, generate_colors
from .common import (
    add_xy_line,
    add_trendline,
    create_fig_and_axes,
    iqr_outlier,
    get_figure_kwargs,
)


# Template for the warning issued when more plots are requested than
# the ``max_plots`` limit allows; ``{0}`` is replaced by the plot count.
_WARNING_MAX_PLOTS = (
    'This will generate {0} plots. '
    'If you want to generate all these plots, '
    'rerun the function with the argument "max_plots={0}".'
)


def create_scatter_legend(axi, color_labels, class_names, show=False,
                          **kwargs):
    """Create legend handles and labels for a class-colored scatter plot.

    One empty scatter artist is created per class, so that the legend
    entries get the same styling as the real points.

    Parameters
    ----------
    axi : object like :class:`matplotlib.axes.Axes`
        The axes the legend handles are created on.
    color_labels : dict of objects like :class:`numpy.ndarray`
        A mapping from class key to the color used for that class.
    class_names : dict of strings or None
        A mapping from class key to the name shown in the legend.
        Keys missing from the mapping (or all keys, if this is None)
        are used as labels directly.
    show : boolean, optional
        If True, the legend is also added to ``axi`` here.
    kwargs : dict, optional
        Additional arguments passed to the scatter method. Used here
        to get a consistent styling.

    Returns
    -------
    patches : list of objects like :class:`matplotlib.artist.Artist`
        The legend handles, one per entry in ``color_labels``.
    labels : list
        The legend labels, in the same order as ``patches``.

    """
    names = {} if class_names is None else class_names
    patches = [
        axi.scatter([], [], color=color, **kwargs)
        for color in color_labels.values()
    ]
    labels = [names.get(key, key) for key in color_labels]
    if show:
        axi.legend(patches, labels, ncol=1)
    return patches, labels
def plot_scatter(data, xvar, yvar, axi=None, xlabel=None, ylabel=None,
                 class_data=None, class_names=None, highlight=None,
                 cmap_class=None, **kwargs):
    """Make a 2D scatter plot of the given data.

    Parameters
    ----------
    data : object like :class:`pandas.core.frame.DataFrame`
        The data we are plotting.
    xvar : string or None
        The column to use as the x-variable. If None, the y-values
        are plotted against their position (0, 1, 2, ...).
    yvar : string
        The column to use as the y-variable.
    axi : object like :class:`matplotlib.axes.Axes`, optional
        An axis to add the plot to. If this is not provided,
        a new axis (and figure) will be created here.
    xlabel : string, optional
        The label to use for the x-axis. If None, we will use xvar,
        or "Data point no." when xvar is None.
    ylabel : string, optional
        The label to use for the y-axis. If None, we will use yvar.
    class_data : object like :class:`pandas.core.series.Series`, optional
        Class information for the points (if available).
    class_names : dict of strings
        A mapping from the class data to labels/names.
    highlight : list of integers, optional
        This can be used to highlight certain points in the plot.
        It is used directly to index the x- and y-data.
    cmap_class : string or object like :class:`matplotlib.colors.Colormap`, optional
        A color map to use for classes.
    kwargs : dict, optional
        Additional settings for the plotting. The entry ``'scatter'``
        (a dict) is passed to the scatter calls and the entry
        ``'scatter-outlier'`` (a dict) is passed to the scatter call
        for the highlighted points. If a new figure is created here,
        the figure settings are extracted from these kwargs.

    Returns
    -------
    fig : object like :class:`matplotlib.figure.Figure`
        The figure containing the plot. This is None if an existing
        axis was given via ``axi``.
    axi : object like :class:`matplotlib.axes.Axes`
        The axis containing the plot.
    patches : list of objects like :class:`matplotlib.artist.Artist`
        The items we will create a legend for.
    labels : list of strings
        The labels for the legend.

    """
    patches, labels = [], []
    color_class, color_labels, idx_class = generate_class_colors(
        class_data, cmap=cmap_class
    )
    fig = None
    if axi is None:
        fig_kw = get_figure_kwargs(kwargs)
        fig, axi = plt.subplots(**fig_kw)

    ydata = data[yvar]
    if xvar is None:
        # No x-variable: plot the values against their position.
        default_xlabel = 'Data point no.'
        xdata = np.arange(len(ydata))
    else:
        default_xlabel = xvar
        xdata = data[xvar]
    # Explicit labels always take precedence over the defaults, also
    # when no x-variable is given.
    axi.set(
        xlabel=default_xlabel if xlabel is None else xlabel,
        ylabel=yvar if ylabel is None else ylabel,
    )

    scatter_kw = kwargs.get('scatter', {})
    if class_data is None:
        axi.scatter(xdata, ydata, **scatter_kw)
    else:
        for class_id, idx in idx_class.items():
            axi.scatter(
                xdata[idx],
                ydata[idx],
                color=color_class[class_id],
                **scatter_kw,
            )
        patches, labels = create_scatter_legend(
            axi, color_labels, class_names, **scatter_kw,
        )
    if highlight is not None:
        scat = axi.scatter(
            xdata[highlight],
            ydata[highlight],
            **kwargs.get('scatter-outlier', {}),
        )
        patches.append(scat)
        labels.append(scat.get_label())
    return fig, axi, patches, labels
def generate_1d_scatter(data, variables, class_data=None, class_names=None,
                        nrows=None, ncols=None, sharex=False, sharey=False,
                        show_legend=True, outliers=False, cmap_class=None,
                        **kwargs):
    """Generate 1D scatter plots from the given data and variables.

    Parameters
    ----------
    data : object like :class:`pandas.core.frame.DataFrame`
        The data we will plot here.
    variables : list of strings
        The variables we will generate scatter plots for.
    class_data : object like :class:`pandas.core.series.Series`, optional
        Class information for the points (if available).
    class_names : dict of strings, optional
        A mapping from the class data to labels/names.
    nrows : integer, optional
        The number of rows to use in a figure.
    ncols : integer, optional
        The number of columns to use in a figure.
    sharex : boolean, optional
        If True, the scatter plots will share the x-axis.
    sharey : boolean, optional
        If True, the scatter plots will share the y-axis.
    show_legend : boolean, optional
        If True, we will create a legend here and show it. The legend
        is only added to the first axes of each figure.
    outliers : boolean, optional
        If True, we will try to mark outliers in the plot and draw
        horizontal lines at the lower/upper outlier bounds.
    cmap_class : string or object like :class:`matplotlib.colors.Colormap`, optional
        A color map to use for classes.
    kwargs : dict, optional
        Additional arguments used for the plotting.

    Returns
    -------
    figures : list of objects like :class:`matplotlib.figure.Figure`
        The figures containing the plots.
    axes : list of objects like :class:`matplotlib.axes.Axes`
        The axes containing the plots.
    outlier_points : dict
        The outlier points found for the variables, looked up by
        variable name. This is empty if ``outliers`` is False.

    """
    nplots = len(variables)
    figures, axes = create_fig_and_axes(
        nplots, nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey,
        **kwargs,
    )
    outlier_points = {}
    bounds = [{}, {}]
    if outliers:
        _, outlier_points, bounds = iqr_outlier(data, variables)

    fig = None
    for i, yvar in enumerate(variables):
        axi = axes[i]
        # We do not want to repeat the legend in all subplots. A
        # separate flag is used so that the ``show_legend`` argument
        # given by the caller is not overwritten.
        show_legend_ax = False
        if axi.figure != fig:
            fig = axi.figure
            show_legend_ax = True
        highlight = outlier_points.get(yvar, None) if outliers else None
        _, _, patches, labels = plot_scatter(
            data,
            None,
            yvar,
            axi=axi,
            class_data=class_data,
            class_names=class_names,
            highlight=highlight,
            cmap_class=cmap_class,
            **kwargs,
        )
        if outliers:
            # bounds[0] holds the lower and bounds[1] the upper limits.
            for limits in (bounds[0], bounds[1]):
                limit = limits.get(yvar, None)
                if limit is not None:
                    axi.axhline(y=limit, ls=':', color='#262626')
        if show_legend and show_legend_ax and patches and labels:
            axi.legend(patches, labels)
    return figures, axes, outlier_points
def generate_2d_scatter(data, variables, class_data=None, class_names=None,
                        nrows=None, ncols=None, sharex=False, sharey=False,
                        show_legend=True, xy_line=False, trendline=False,
                        cmap_class=None, shorten_variables=False,
                        **kwargs):
    """Generate 2D scatter plots for all pairs of the given variables.

    One scatter plot is made for each combination of two variables.

    Parameters
    ----------
    data : object like :class:`pandas.core.frame.DataFrame`
        The data we will plot here.
    variables : list of strings
        The variables we will generate scatter plots for.
    class_data : object like :class:`pandas.core.series.Series`, optional
        Class information for the points (if available).
    class_names : dict of strings, optional
        A mapping from the class data to labels/names.
    nrows : integer, optional
        The number of rows to use in a figure.
    ncols : integer, optional
        The number of columns to use in a figure.
    sharex : boolean, optional
        If True, the scatter plots will share the x-axis.
    sharey : boolean, optional
        If True, the scatter plots will share the y-axis.
    show_legend : boolean, optional
        If True, a legend is added to the first axes of each figure.
    xy_line : boolean, optional
        If True, we will add a x=y line to the plot.
    trendline : boolean, optional
        If True, we will add a trend line to the plot.
    cmap_class : string or object like :class:`matplotlib.colors.Colormap`, optional
        A color map to use for classes.
    shorten_variables : boolean, optional
        If True, variable names longer than 5 characters are
        shortened to their first 3 characters followed by "..."
        when they are used as axis labels.
    kwargs : dict, optional
        Additional arguments used for the plotting.

    Returns
    -------
    figures : list of objects like :class:`matplotlib.figure.Figure`
        The figures containing the plots.
    axes : list of objects like :class:`matplotlib.axes.Axes`
        The axes containing the plots.

    """
    n_pairs = comb(len(variables), 2, exact=True)
    figures, axes = create_fig_and_axes(
        n_pairs, nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey,
        **kwargs,
    )
    current_fig = None
    for idx, (xvar, yvar) in enumerate(combinations(variables, 2)):
        axi = axes[idx]
        # Only the first axes of each figure gets a legend.
        first_in_figure = axi.figure != current_fig
        if first_in_figure:
            current_fig = axi.figure

        xlabel = ylabel = None
        if shorten_variables:
            xlabel = xvar[:3] + '...' if len(xvar) > 5 else None
            ylabel = yvar[:3] + '...' if len(yvar) > 5 else None

        _, _, patches, labels = plot_scatter(
            data,
            xvar,
            yvar,
            axi=axi,
            xlabel=xlabel,
            ylabel=ylabel,
            class_data=class_data,
            class_names=class_names,
            cmap_class=cmap_class,
            **kwargs,
        )
        if xy_line:
            patches.append(add_xy_line(axi, alpha=0.7, color='black'))
            labels.append('x = y')
        if trendline:
            patches.append(
                add_trendline(axi, data[xvar], data[yvar],
                              alpha=0.7, ls='--', color='black')
            )
            labels.append('y = a + bx')
        if show_legend and first_in_figure and patches and labels:
            axi.legend(patches, labels)
    return figures, axes
def plot_3d_scatter(data, xvar, yvar, zvar, class_data=None,
                    class_names=None, cmap_class=None, **kwargs):
    """Create a new figure with a 3D scatter plot of three variables.

    Parameters
    ----------
    data : object like :class:`pandas.core.frame.DataFrame`
        The data we are plotting.
    xvar : string
        The column to use as the x-variable.
    yvar : string
        The column to use as the y-variable.
    zvar : string
        The column to use as the z-variable.
    class_data : object like :class:`pandas.core.series.Series`, optional
        Class information for the points (if available). If given,
        each class is drawn in its own color and a legend is added.
    class_names : dict of strings, optional
        A mapping from the class data to labels/names.
    cmap_class : string or object like :class:`matplotlib.colors.Colormap`, optional
        A color map to use for classes.
    kwargs : dict, optional
        Additional arguments passed to the scatter calls.

    Returns
    -------
    fig : object like :class:`matplotlib.figure.Figure`
        The figure containing the plot.
    axi : object like :class:`matplotlib.axes.Axes`
        The axis containing the plot.

    """
    color_class, color_labels, idx_class = generate_class_colors(
        class_data, cmap=cmap_class
    )
    fig = plt.figure()
    axi = fig.add_subplot(111, projection='3d')
    label_setters = (
        (axi.set_xlabel, xvar),
        (axi.set_ylabel, yvar),
        (axi.set_zlabel, zvar),
    )
    for set_label, name in label_setters:
        set_label(name, labelpad=15)

    xcol, ycol, zcol = data[xvar], data[yvar], data[zvar]
    if class_data is not None:
        for class_id, idx in idx_class.items():
            axi.scatter(
                xcol[idx],
                ycol[idx],
                zcol[idx],
                color=color_class[class_id],
                **kwargs
            )
        create_scatter_legend(
            axi, color_labels, class_names, show=True, **kwargs
        )
    else:
        axi.scatter(xcol, ycol, zcol, **kwargs)
    fig.tight_layout()
    return fig, axi
def generate_3d_scatter(data, variables, class_data=None, class_names=None,
                        max_plots=5, **kwargs):
    """Generate 3D scatter plots for all triplets of the given variables.

    One figure is created for each combination of three variables.
    If this would give more than ``max_plots`` figures, nothing is
    plotted; a warning is issued instead and empty lists are returned.

    Parameters
    ----------
    data : object like :class:`pandas.core.frame.DataFrame`
        The data we will plot here.
    variables : list of strings
        The variables we will generate scatter plots for.
    class_data : object like :class:`pandas.core.series.Series`, optional
        Class information for the points (if available).
    class_names : dict of strings, optional
        A mapping from the class data to labels/names.
    max_plots : integer, optional
        The maximum number of plots to create.
    kwargs : dict, optional
        Additional arguments used for the plotting.

    Returns
    -------
    figures : list of objects like :class:`matplotlib.figure.Figure`
        The figures created here.
    axes : list of objects like :class:`matplotlib.axes.Axes`
        The axes created here.

    Raises
    ------
    ValueError
        If fewer than 3 variables are given.

    """
    if len(variables) < 3:
        raise ValueError(
            'For generating 3D plots, at least 3 variables must be provided.'
        )
    nplots = comb(len(variables), 3, exact=True)
    if nplots > max_plots:
        warnings.warn(_WARNING_MAX_PLOTS.format(nplots))
        return [], []

    created = [
        plot_3d_scatter(
            data,
            xvar,
            yvar,
            zvar,
            class_data=class_data,
            class_names=class_names,
            **kwargs
        )
        for xvar, yvar, zvar in combinations(variables, 3)
    ]
    figures = [figi for figi, _ in created]
    axes = [axi for _, axi in created]
    return figures, axes
def scatter_1d_flat(data, class_data=None, class_names=None, scaler=None,
                    add_average=False, add_lines=False, cmap_lines=None,
                    cmap_class=None, split_class=False,
                    scatter_settings=None, line_settings=None):
    """Make a flat plot of several variables.

    Here, the points on the x-axis are the variables, while the
    y-values are points for each data series. The actual plotting
    is done by one of two helpers, depending on whether class
    information is given or not.

    Parameters
    ----------
    data : object like :class:`pandas.core.frame.DataFrame`
        The data we are plotting.
    class_data : object like :class:`pandas.core.series.Series`, optional
        Class information for the points (if available).
    class_names : dict of strings
        A mapping from the class data to labels/names. Only used
        when ``class_data`` is given.
    scaler : callable, optional
        A function that can be used to scale the variables.
    add_average : boolean, optional
        If True, we will show the averages for each variable.
    add_lines : boolean, optional
        If True, we will show lines for each "measurement".
    cmap_lines : string or object like :class:`matplotlib.colors.Colormap`, optional
        A color map to use for lines. Only used when ``class_data``
        is None.
    cmap_class : string or object like :class:`matplotlib.colors.Colormap`, optional
        A color map to use for classes. Only used when
        ``class_data`` is given.
    split_class : boolean, optional
        If True, the plot with class information will be split into
        one plot for each class.
    scatter_settings : dict, optional
        Additional settings for the scatter plot.
    line_settings : dict, optional
        Additional settings for plotting lines.

    Returns
    -------
    figures : objects like :class:`matplotlib.figure.Figure`
        The figure created here.
    axes : object(s) like :class:`matplotlib.axes.Axes`
        The axes created here. A single axis if ``class_data`` is
        None, otherwise a collection of axes.

    """
    # Settings understood by both plotting helpers.
    shared = {
        'scaler': scaler,
        'add_average': add_average,
        'add_lines': add_lines,
        'line_settings': line_settings,
        'scatter_settings': scatter_settings,
    }
    if class_data is not None:
        return _scatter_1d_flat_class(
            data,
            class_data,
            split_class=split_class,
            class_names=class_names,
            cmap_class=cmap_class,
            **shared,
        )
    return _scatter_1d_flat_no_class(data, cmap_lines=cmap_lines, **shared)
def _get_settings_if_empty(settings):
    """Return ``settings``, or a new empty dict if ``settings`` is None."""
    return {} if settings is None else settings
def _scatter_1d_flat_no_class(data, scaler=None, add_average=False,
                              add_lines=False, cmap_lines=None,
                              scatter_settings=None, line_settings=None):
    """Make a flat plot of several variables, without class information.

    Here, the points on the x-axis are the variables, while the
    y-values are points for each data series.

    Parameters
    ----------
    data : object like :class:`pandas.core.frame.DataFrame`
        The data we are plotting. Each column is one variable.
    scaler : callable, optional
        A function that can be used to scale the variables. It is
        applied to each column separately.
    add_average : boolean, optional
        If True, we will show the averages for each variable.
    add_lines : boolean, optional
        If True, we will show lines for each "measurement".
    cmap_lines : string or object like :class:`matplotlib.colors.Colormap`, optional
        A color map to use for lines.
    scatter_settings : dict, optional
        Additional settings for the scatter plot.
    line_settings : dict, optional
        Additional settings for plotting lines.

    Returns
    -------
    fig : object like :class:`matplotlib.figure.Figure`
        The figure containing the plot.
    axi : object like :class:`matplotlib.axes.Axes`
        The axis containing the plot.

    """
    fig, axi = plt.subplots(constrained_layout=True)
    variables = data.columns
    axi.set_xticks(range(len(variables)))
    axi.set_xticklabels(variables, rotation='vertical')
    axi.set_ylabel('Values' if scaler is None else 'Scaled values')

    # One row per variable: the (possibly scaled) values, and the
    # x-position of the variable repeated for each value.
    columns = [
        data[name] if scaler is None else scaler(data[name])
        for name in variables
    ]
    xvalues = np.array(
        [np.full_like(column, pos) for pos, column in enumerate(columns)]
    )
    yvalues = np.array(columns)

    if add_lines:
        line_kw = {} if line_settings is None else line_settings
        colors = generate_colors(len(yvalues[0, :]), cmap=cmap_lines)
        drawn = axi.plot(xvalues, yvalues, zorder=1, **line_kw)
        for artist, shade in zip(drawn, colors):
            artist.set_color(shade)

    scatter_kw = {} if scatter_settings is None else scatter_settings
    axi.scatter(xvalues, yvalues, zorder=2, **scatter_kw)

    if add_average:
        means = np.average(yvalues, axis=1)
        positions = range(len(means))
        marks = axi.scatter(positions, means, zorder=3, marker='X')
        # Connect the averages with a line in the marker color.
        axi.plot(positions, means, color=marks.get_facecolors()[0])
    return fig, axi
def _scatter_1d_flat_class(data, class_data, class_names=None, scaler=None,
                           add_lines=False, add_average=False,
                           cmap_class=None, split_class=False,
                           scatter_settings=None, line_settings=None):
    """Make a flat plot of several variables, colored by class.

    Here, the points on the x-axis are the variables, while the
    y-values are points for each data series. The class information
    is used for coloring.

    Parameters
    ----------
    data : object like :class:`pandas.core.frame.DataFrame`
        The data we are plotting. Each column is one variable.
    class_data : object like :class:`pandas.core.series.Series`
        Class information for the points.
    class_names : dict of strings
        A mapping from the class data to labels/names.
    scaler : callable, optional
        A function that can be used to scale the variables. It is
        applied to each column separately.
    add_lines : boolean, optional
        If True, we will show lines for each "measurement".
    add_average : boolean, optional
        If True, we will show the averages for each variable
        (one average per class).
    cmap_class : string or object like :class:`matplotlib.colors.Colormap`, optional
        A color map to use for classes.
    split_class : boolean, optional
        If True, the plot with class information will be split into
        one plot for each class.
    scatter_settings : dict, optional
        Additional settings for the scatter plot.
    line_settings : dict, optional
        Additional settings for plotting lines.

    Returns
    -------
    fig : object like :class:`matplotlib.figure.Figure`
        The figure created here.
    all_axes : sequence of objects like :class:`matplotlib.axes.Axes`
        The axes created here: one per class if ``split_class`` is
        True, otherwise a list with a single axis.

    """
    color_class, color_labels, idx_class = generate_class_colors(
        class_data, cmap=cmap_class
    )
    if split_class:
        # squeeze=False makes subplots always return an array of
        # axes, also when there is just a single class.
        fig, axes_grid = plt.subplots(
            nrows=1,
            ncols=len(idx_class),
            constrained_layout=True,
            sharex=True,
            sharey=True,
            squeeze=False,
        )
        all_axes = axes_grid.flatten()
        axes = dict(zip(idx_class, all_axes))
    else:
        fig, axi = plt.subplots(constrained_layout=True)
        all_axes = [axi]
        axes = {class_id: axi for class_id in idx_class}

    variables = data.columns
    for axi in all_axes:
        axi.set_xticks(range(len(variables)))
        axi.set_xticklabels(variables, rotation='vertical')

    # For each class: one entry per variable with the values for
    # that class, and the matching x-positions.
    yvalues = {class_id: [] for class_id in idx_class}
    xvalues = {class_id: [] for class_id in idx_class}
    for i, variable in enumerate(variables):
        yval = data[variable]
        if scaler is not None:
            yval = scaler(yval)
        for class_id, idx in idx_class.items():
            yvali = yval[idx]
            yvalues[class_id].append(yvali)
            xvalues[class_id].append(np.full_like(yvali, i))

    line_kw = _get_settings_if_empty(line_settings)
    if add_lines:
        for class_id in idx_class:
            axes[class_id].plot(
                xvalues[class_id],
                yvalues[class_id],
                color=color_class[class_id],
                zorder=1,
                **line_kw,
            )

    scatter_kw = _get_settings_if_empty(scatter_settings)
    for class_id in idx_class:
        axes[class_id].scatter(
            xvalues[class_id],
            yvalues[class_id],
            color=color_class[class_id],
            zorder=2,
            **scatter_kw,
        )

    if add_average:
        for class_id in idx_class:
            avg = np.average(yvalues[class_id], axis=1)
            axes[class_id].scatter(
                range(len(avg)),
                avg,
                zorder=3,
                marker='X',
                color=color_class[class_id],
                edgecolor='black',
            )
    # The legend is only shown in the first axes.
    create_scatter_legend(
        all_axes[0], color_labels, class_names, show=True, **scatter_kw
    )
    return fig, all_axes