Source code for psynlig.heatmap

# Copyright (c) 2020, Anders Lervik.
# Distributed under the MIT License. See LICENSE for more info.
"""A module defining helper methods for creating a heat map."""
from matplotlib import pyplot as plt
from matplotlib.ticker import StrMethodFormatter
from matplotlib.patches import Circle, Rectangle
import numpy as np
from .common import set_up_fig_and_axis, get_figure_kwargs


[docs]def create_bubbles(data, img, axi): """Create bubbles for a heat map. Parameters ---------- data : object like :class:`numpy.ndarray` A 2D numpy array of shape (N, M). img : object like :class:`matplotlib.image.AxesImage` The heat map image we have generated. axi : object like :class:`matplotlib.axes.Axes` The axis to add the bubbles to. """ vals = img.get_array() for i in range(data.shape[0]): for j in range(data.shape[1]): value = img.norm(vals[i, j]) radius = np.abs(vals[i, j]) * 0.5 * 0.9 color = img.cmap(value) if i % 2 == 0: rect = Rectangle((j-0.5, i-0.5), 1, 1, color='0.8') else: rect = Rectangle((j-0.5, i-0.5), 1, 1, color='0.9') axi.add_artist(rect) circle = Circle((j, i), radius=radius, color=color) axi.add_artist(circle) img.set_visible(False)
[docs]def heatmap(data, row_labels, col_labels, axi=None, fig=None, cbar_kw=None, cbarlabel='', bubble=False, **kwargs): """Create a heat map from a numpy array and two lists of labels. Parameters ---------- data : object like :class:`numpy.ndarray` A 2D numpy array of shape (N, M). row_labels : list of strings A list or array of length N with the labels for the rows. col_labels : list of strings A list or array of length M with the labels for the columns. axi : object like :class:`matplotlib.axes.Axes`, optional An axis to plot the heat map. If not provided, a new axis will be created. fig : object like :class:`matplotlib.figure.Figure`, optional The figure where the axes resides in. If given, tight layout will be applied. cbar_kw : dict, optional A dictionary with arguments to the creation of the color bar. cbarlabel : string, optional The label for the color bar. bubble : boolean, optional If True, we will draw bubbles indicating the size of the given data points. **kwargs : dict, optional Additional arguments for drawing the heat map. Returns ------- fig : object like :class:`matplotlib.figure.Figure` The figure in which the heatmap is plotted. axi : object like :class:`matplotlib.axes.Axes` The axis to which the heatmap is added. img : object like :class:`matplotlib.image.AxesImage` The generated heat map. cbar : object like :class:`matplotlib.colorbar.Colorbar` The color bar created for the heat map. """ fig, axi = set_up_fig_and_axis(fig, axi) # Plot the heatmap: img = axi.imshow(data, **kwargs) # Check if this is a bubble map: if bubble: create_bubbles(data, img, axi) # Create colorbars: if cbar_kw is None: cbar_kw = {} cbar = axi.figure.colorbar(img, ax=axi, **cbar_kw) cbar.ax.set_ylabel(cbarlabel, rotation=-90, va='bottom') # Show ticks using the provided labels: axi.set_xticks(np.arange(data.shape[1])) axi.set_xticklabels( col_labels, rotation=-30, horizontalalignment='right', rotation_mode='anchor' ) axi.set_yticks(np.arange(data.shape[0])) axi.set_yticklabels(row_labels) # Labels on top: axi.tick_params( top=True, bottom=False, labeltop=True, labelbottom=False ) # Hide spines off: for _, spine in axi.spines.items(): spine.set_visible(False) # Add grid: axi.grid(False) axi.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True) axi.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True) if bubble: axi.grid(which='minor', color='white', linestyle='-', linewidth=3) axi.tick_params(which='minor', bottom=False, left=False) axi.tick_params(which='major', top=False, left=False) else: axi.grid(which='minor', color='white', linestyle='-', linewidth=3) axi.tick_params(which='minor', bottom=False, left=False) if fig is not None: fig.tight_layout() return fig, axi, img, cbar
[docs]def annotate_heatmap(img, data=None, val_fmt='{x:.2f}', textcolors=None, **kwargs): """Annotate a heatmap with values. Parameters ---------- img : object like :class:`matplotlib.image.AxesImage` The heat map image to be labeled. data : object like :class:`numpy.ndarray`, optional Data used to annotate the heat map. If not given, the data in the heat map image (``img``) is used. val_fmt : string, optional The format of the annotations inside the heat map. textcolors : list of strings, optional Colors used for the text. The number of colors provided defines a binning for the data values, and values are colored with the corresponding color. If no colors are provided, all are colored black. **kwargs : dict, optional Extra arguments used for creating text labels. """ if data is None: data = img.get_array() # Create arguments for text: textkw = kwargs.copy() textkw.update( { 'horizontalalignment': 'center', 'verticalalignment': 'center', } ) # Get the formatter: formatter = StrMethodFormatter(val_fmt) if textcolors is None: textcolors = ['black'] texts = [] bins = np.linspace(0, 1, len(textcolors) + 1) for i in range(data.shape[0]): for j in range(data.shape[1]): val = img.norm(data[i, j]) idx = np.digitize(val, bins, right=True) idx = max(idx - 1, 0) textkw.update(color=textcolors[idx]) text = img.axes.text(j, i, formatter(data[i, j], None), **textkw) texts.append(text) return texts
[docs]def plot_correlation_heatmap(data, val_fmt='{x:.2f}', bubble=False, annotate=True, textcolors=None, **kwargs): """Plot a heat map to investigate correlations. Parameters ---------- data : object like :class:`pandas.DataFrame` The data we will generate a heat correlation map from. val_fmt : string, optional The format of the annotations inside the heat map. bubble : optional, boolean If True, we will draw bubbles to indicate the size of the given data points. annotate : boolean, optional If True, we will annotate the plot with values. textcolors : list of strings, optional Colors used for the text. The number of colors provided defines a binning for the data values, and values are colored with the corresponding color. If no colors are provided, all are colored black. **kwargs : dict, optional Arguments used for drawing the heat map. Returns ------- fig : object like :class:`matplotlib.figure.Figure` The figure in which the heatmap is plotted. ax1 : object like :class:`matplotlib.axes.Axes` The axis to which the heat map is added. """ corr = data.corr(method='pearson') fig1, ax1 = plot_annotated_heatmap( corr, data.columns, data.columns, cbarlabel='Pearson correlation coefficient', val_fmt=val_fmt, bubble=bubble, textcolors=textcolors, annotate=annotate, **kwargs ) return fig1, ax1
[docs]def plot_annotated_heatmap(data, row_labels, col_labels, cbarlabel='', val_fmt='{x:.2f}', textcolors=None, bubble=False, annotate=True, **kwargs): """Plot a heat map to investigate correlations. Parameters ---------- data : object like :class:`numpy.ndarray` A 2D numpy array of shape (N, M). row_labels : list of strings A list or array of length N with the labels for the rows. col_labels : list of strings A list or array of length M with the labels for the columns. cbarlabel : string, optional The label for the color bar. val_fmt : string, optional The format of the annotations inside the heat map. textcolors : list of strings, optional Colors used for the text. The number of colors provided defines a binning for the data values, and values are colored with the corresponding color. If no colors are provided, all are colored black. bubble : boolean, optional If True, we will draw bubbles to indicate the size of the given data points. annotate : boolean, optional If True, we will annotate the plot with values. **kwargs : dict, optional Arguments used for drawing the heat map. Returns ------- fig : object like :class:`matplotlib.figure.Figure` The figure in which the heatmap is plotted. ax1 : object like :class:`matplotlib.axes.Axes` The axis to which the heat map is added. """ fig_kw = get_figure_kwargs(kwargs) fig1, ax1 = plt.subplots(**fig_kw) _, _, img, _ = heatmap( data, row_labels, col_labels, axi=ax1, cbarlabel=cbarlabel, bubble=bubble, **kwargs.get('heatmap', {}), ) if annotate: annotate_heatmap( img, val_fmt=val_fmt, textcolors=textcolors, **kwargs.get('text', {}), ) return fig1, ax1