Source code for util.visualization.mean_std_plot

import matplotlib as mpl
# To facilitate plotting on a headless server
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_mean_std(x=None, arr=None, suptitle='', title='', xlabel='X', ylabel='Y', xlim=None, ylim=None):
    """
    Plots the accuracy/loss curve over several runs with standard deviation and mean.

    The per-epoch mean across runs is drawn as a solid line and the per-epoch
    min and max as dashed lines. A band of +/- one standard deviation around
    the mean is shaded. The figure is rendered, converted to an RGB image and
    closed again.

    Parameters
    ----------
    x: array-like or None
        contains the ticks on the x-axis, one entry per column of ``arr``.
        If None, the column indices ``0 .. n-1`` are used.
    arr: array-like
        contains the accuracy values for each epoch per run, laid out as
        (runs, epochs). A 1-D array is treated as a single run.
    suptitle: str
        title for the plot
    title: str
        sub-title for the plot
    xlabel: str
        label for the x-axis
    ylabel: str
        label for the y-axis
    xlim: float, (float, float) or None
        optionally specify an upper limit on the x-axis (a single number),
        or both limits as a (lower, upper) pair
    ylim: float, (float, float) or None
        optionally specify an upper limit on the y-axis (a single number),
        or both limits as a (lower, upper) pair

    Returns
    -------
    data: numpy.ndarray
        Contains an RGB image of the plotted accuracy curves as a
        (height, width, 3) uint8 array.

    Raises
    ------
    ValueError
        If ``arr`` is missing, empty, or has more than two dimensions.
    """
    if arr is None:
        raise ValueError('arr must be provided')
    # A single 1-D curve is treated as one run so the per-epoch reductions work.
    arr = np.atleast_2d(np.asarray(arr))
    if arr.ndim != 2 or arr.size == 0:
        raise ValueError(
            'arr must be a non-empty (runs, epochs) array, got shape {}'.format(arr.shape))

    arr_mean = np.mean(arr, 0)
    arr_std = np.std(arr, 0)
    arr_min = np.min(arr, 0)
    arr_max = np.max(arr, 0)
    # The lines and the shaded band must share the same x-coordinates;
    # otherwise the band is misaligned whenever ``x`` is given.
    xs = np.arange(len(arr_mean)) if x is None else np.asarray(x)

    # A brand-new figure (not figure number 1), so a figure the caller already
    # has open is neither drawn into nor closed by this function.
    fig = plt.figure()
    try:
        # seaborn styles only apply to axes created inside this context.
        with sns.axes_style('darkgrid'):
            fig.suptitle(suptitle)
            axes = fig.add_subplot(111)
            axes.set_title(title)
            axes.set_xlabel(xlabel)
            axes.set_ylabel(ylabel)

            axes.plot(xs, arr_mean, '-', color='#0000b3', label='Score')
            axes.plot(xs, arr_min, color='#4d4dff', linestyle='dashed', label='Min')
            axes.plot(xs, arr_max, color='#4d4dff', linestyle='dashed', label='Max')
            axes.fill_between(xs, arr_mean - arr_std, arr_mean + arr_std,
                              color='#9999ff', alpha=0.2)

            # Limits are applied after plotting: a single number then caps only
            # the upper end and keeps the data-driven lower end. A positional
            # scalar would instead set the *lower* bound against the empty
            # axes' default view.
            if xlim is not None:
                if np.ndim(xlim) == 0:
                    axes.set_xlim(right=xlim)
                else:
                    axes.set_xlim(xlim)
            if ylim is not None:
                if np.ndim(ylim) == 0:
                    axes.set_ylim(top=ylim)
                else:
                    axes.set_ylim(ylim)

            axes.legend(loc='best')
            return _figure_to_rgb(fig)
    finally:
        # Always release the figure, even if plotting or rendering fails.
        plt.close(fig)


def _figure_to_rgb(fig):
    """Render ``fig`` and return its pixels as a (height, width, 3) uint8 RGB array."""
    fig.canvas.draw()
    # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8,
    # removed in 3.10). Its memoryview is already shaped (H, W, 4) in physical
    # pixels, so no reshape by the logical canvas size is needed.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop the alpha channel and copy, so the result does not alias the
    # renderer's internal buffer after the figure is closed.
    return np.array(rgba[..., :3], dtype=np.uint8)