"""Source code for plotting.marginals."""

import copy
import logging
import os

import numpy as num
from arviz import plot_density
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator
from pyrocko.cake_plot import str_to_mpl_color as scolor
from pyrocko.plot import AutoScaler, mpl_graph_color, mpl_papersize, nice_value

from beat import utility
from beat.config import bem_mode_str, dist_vars, geometry_mode_str
from beat.defaults import hypername
from beat.heart import defaults
from beat.models import Stage, load_stage

from .common import (
    format_axes,
    get_result_point,
    get_transform,
    histplot_op,
    kde2plot,
    plot_exists,
    save_figs,
)

# Module-level logger used by the plotting functions of this module.
logger = logging.getLogger(name="plotting.marginals")


def unify_tick_intervals(axs, varnames, ntickmarks_max=5, axis="x"):
    """
    Determine common axis ranges for variables that share a unit class.

    For every unit class in ``utility.unit_sets`` the smallest and largest
    axis range among the axes showing a variable of that class is collected.
    If the smallest range is below ``max_range / ntickmarks_max`` it is
    raised to that value, so that an increment derived from it does not
    produce more than ``ntickmarks_max`` increments on the widest axis of
    the class.

    Parameters
    ----------
    axs : :class:`numpy.ndarray` of :class:`matplotlib.axes.Axes`
        Axes to evaluate; paired with ``varnames`` in column-major ("F") order.
    varnames : list of str
        Name of the variable shown on each axis.
    ntickmarks_max : int
        Maximum number of increments on the widest axis of a unit class.
    axis : str
        "x" or "y"; the axis whose limits are evaluated.

    Returns
    -------
    dict : with unit-set names as keys and [min_range, max_range] (floats)
        as values; sets without any matching variable keep [inf, -inf].

    Raises
    ------
    ValueError
        If ``axis`` is neither "x" nor "y".
    """
    if axis not in ("x", "y"):
        raise ValueError('Only "x" or "y" allowed!')

    unities = {setname: [num.inf, -num.inf] for setname in utility.unit_sets}

    for ax, varname in zip(axs.ravel("F"), varnames):
        lower, upper = ax.get_xlim() if axis == "x" else ax.get_ylim()
        # plain float instead of the 1-element array num.diff would return;
        # such arrays hit NumPy's deprecated ndim > 0 scalar conversion when
        # formatted with "%f" or passed on as scalars
        varrange = float(upper - lower)
        for setname, tset in utility.unit_sets.items():
            if varname in tset:
                ranges = unities[setname]
                ranges[0] = min(ranges[0], varrange)
                ranges[1] = max(ranges[1], varrange)

    for setname, (min_range, max_range) in unities.items():
        max_range_frac = max_range / ntickmarks_max
        if max_range_frac > min_range:
            logger.debug(
                "Range difference between min and max for %s is large!"
                " Extending min_range to %f" % (setname, max_range_frac)
            )
            unities[setname] = [max_range_frac, max_range]

    return unities
def apply_unified_axis(
    axs, varnames, unities, axis="x", ntickmarks_max=3, scale_factor=2 / 3
):
    """
    Apply common tick increments (as computed by ``unify_tick_intervals``).

    Axes whose variable is in ``utility.grouped_vars`` get their limits
    snapped to an increment derived from the minimum range of the variable's
    unit class, truncated to the variable's physical bounds, and fixed ticks
    at that increment. All other axes get a ``MaxNLocator`` with 3 bins.

    Parameters
    ----------
    axs : :class:`numpy.ndarray` of :class:`matplotlib.axes.Axes`
        Axes to modify; paired with ``varnames`` in column-major ("F") order.
    varnames : list of str
        Name of the variable shown on each axis.
    unities : dict
        Unit-set name -> [min_range, max_range].
    axis : str
        "x" or "y"; the axis to modify.
    ntickmarks_max : int
        Approximate number of ticks passed to the ``AutoScaler``.
    scale_factor : float
        Factor applied to the minimum range before rounding it to a nice
        increment.

    Raises
    ------
    ValueError
        If ``axis`` is neither "x" nor "y".
    """
    if axis not in ("x", "y"):
        raise ValueError('Only "x" or "y" allowed!')

    naxs = axs.size
    nvars = len(varnames)
    if naxs != nvars:
        logger.warning(
            "Inconsistenet number of Axes: %i and variables: %i!" % (naxs, nvars)
        )

    for ax, v in zip(axs.ravel("F"), varnames):
        if axis == "x":
            get_lim, set_lim, axis_obj = ax.get_xlim, ax.set_xlim, ax.xaxis
        else:
            get_lim, set_lim, axis_obj = ax.get_ylim, ax.set_ylim, ax.yaxis

        if v not in utility.grouped_vars:
            axis_obj.set_major_locator(MaxNLocator(nbins=3))
            continue

        for setname, ranges in unities.items():
            if v not in utility.unit_sets[setname]:
                continue

            # .item() accepts plain floats as well as 1-element arrays
            base_range = num.asarray(ranges[0], dtype=float).item() * scale_factor
            if not num.isfinite(base_range) or base_range == 0.0:
                # nice_value does not terminate on inf and a zero increment
                # would make num.arange fail
                logger.debug(
                    "Skipping axis unification for %s: invalid range %s",
                    v,
                    base_range,
                )
                continue

            inc = nice_value(base_range)
            autos = AutoScaler(inc=inc, snap="on", approx_ticks=ntickmarks_max)
            lo, hi, _ = autos.make_scale(get_lim(), override_mode="min-max")

            # check physical bounds if passed truncate
            phys_min, phys_max = defaults[v].physical_bounds
            if lo < phys_min:
                lo = phys_min
            if hi > phys_max:
                hi = phys_max

            set_lim((lo, hi))

            # num.arange with a float step may overshoot ``hi`` (rounding, or
            # ``hi`` truncated to a bound that is no multiple of ``inc``);
            # set_ticks would then expand the view limits beyond ``hi``
            tol = abs(inc) * 1e-6
            ticks = [
                tick
                for tick in num.arange(lo, hi + inc, inc).tolist()
                if tick <= hi + tol
            ]
            axis_obj.set_ticks(ticks)
def traceplot(
    trace,
    varnames=None,
    lines=None,
    chains=None,
    combined=False,
    grid=False,
    varbins=None,
    nbins=40,
    color=None,
    source_idxs=None,
    alpha=0.35,
    priors=None,
    prior_alpha=1,
    prior_style="--",
    posterior=None,
    plot_style="kde",
    prior_bounds=None,
    unify=True,
    qlist=[0.1, 99.9],
    kwargs=None,
):
    """
    Plots posterior pdfs as histograms from multiple mtrace objects.

    Modified from pymc3.

    Parameters
    ----------
    trace : result of MCMC run
    varnames : list of variable names
        Variables to be plotted, if None all variable are plotted. The list
        is not modified.
    lines : dict
        Dictionary of variable name / value to be overplotted as vertical
        lines to the posteriors and horizontal lines on sample values
        e.g. mean of posteriors, true values of a simulation
    chains : int or list of ints
        chain indexes to select from the trace
    combined : bool
        Flag for combining multiple chains into a single chain. If False
        (default), chains will be plotted separately.
    grid : bool
        Flag for adding gridlines to histogram. Defaults to False.
    varbins : list of arrays
        Binning arrays, in the order in which this function returns them
        (one per plotted variable/source); if None they will be created.
    nbins : int
        Number of bins for each histogram
    color : tuple
        mpl color tuple
    source_idxs : list
        array like, indexes to sources to plot marginals
    alpha : float
        Alpha value for plot line. Defaults to 0.35.
    priors, prior_alpha, prior_style :
        Not used by this function; kept for backward compatibility.
    posterior : str
        To mark posterior value in distribution 'max', 'min', 'mean', 'all';
        None or 'None' disables the marker.
    plot_style : str
        'kde', 'pdf' or 'cdf'
    prior_bounds : dict
        variable name -> object with ``lower``/``upper`` bounds, shown in the
        axis labels
    unify : bool
        If true axis units that belong to one group e.g. [km] will
        have common axis increments
    qlist : list
        of quantiles used for the bins and the histplot op.
        Default: [0.1, 99.9]
    kwargs : dict
        for histplot op; the keys "ntickmarks_max" and "scale_factor" are
        consumed here. The passed dict is not modified.

    Returns
    -------
    figs : list of figures
    fig_axs : list of axes arrays, one per figure
    varbins : list of binning arrays
    """
    fontsize = 10

    # copies / fresh containers: defaults must not be shared between calls
    # and the caller's dict must not be modified (keys are popped/set below)
    kwargs = {} if kwargs is None else dict(kwargs)
    lines = {} if lines is None else lines
    prior_bounds = {} if prior_bounds is None else prior_bounds

    ntickmarks_max = kwargs.pop("ntickmarks_max", 3)
    scale_factor = kwargs.pop("scale_factor", 2 / 3)

    def make_bins(data, nbins=40, qlist=None):
        """Return ``nbins`` equally spaced values spanning ``data`` (between
        the first and last ``qlist`` percentiles if given)."""
        d = data.ravel()
        if qlist is not None:
            qu = num.percentile(d, q=qlist)
            mind, maxd = qu[0], qu[-1]
        else:
            mind = d.min()
            maxd = d.max()
        return num.linspace(mind, maxd, nbins)

    if varnames is None:
        varnames = [name for name in trace.varnames if not name.endswith("_")]
    else:
        varnames = list(varnames)

    for skip_name in ("geo_like", "seis_like"):
        if skip_name in varnames:
            varnames.remove(skip_name)

    # posterior may be None (the default) or the string "None" from configs
    show_posterior = posterior not in (None, "None")
    if show_posterior:
        llk = trace.get_values("like", combine=combined, chains=chains, squeeze=False)
        llk = num.squeeze(llk[0])
        llk = num.atleast_2d(llk)

        posterior_idxs = utility.get_fit_indexes(llk)

        colors = {
            "mean": scolor("orange1"),
            "min": scolor("butter1"),
            "max": scolor("scarletred2"),
        }

    n = nvar = len(varnames)
    if n == 1 and source_idxs is None:
        raise IOError(
            "If only single variable is selected source_idxs need to be specified!"
        )
    elif n == 1 and len(source_idxs) > 1:
        n = len(source_idxs)
        logger.info("Plotting of patches in panels ...")
        varnames = varnames * n
    else:
        logger.info("Plotting variables in panels ...")

    if varbins is None:
        make_bins_flag = True
        varbins = []
    else:
        make_bins_flag = False

    input_color = copy.deepcopy(color)
    backup_source_idxs = copy.deepcopy(source_idxs)

    # subfigure handling
    nrowtotal = int(num.ceil(n / 2.0))
    ncol = 2
    nrow_max = 4
    nplots_page_max = nrow_max * ncol
    n_subplots_total = nrowtotal * ncol
    ntotal_figs, nrest_subplots = utility.mod_i(n_subplots_total, nplots_page_max)
    nsubplots_page = [nplots_page_max for _ in range(ntotal_figs)]
    if nrest_subplots:
        nsubplots_page.append(nrest_subplots)

    figs = []
    fig_axs = []
    var_idx = 0  # global panel index, runs across pages
    bin_idx = 0  # index into varbins, in the order bins are created
    varname_page_idx = 0
    for nsubplots in nsubplots_page:
        width, height = mpl_papersize("a4", "portrait")
        height_subplot = height / nrow_max
        nrow = int(num.ceil(nsubplots / ncol))

        fig, axs = plt.subplots(nrow, ncol, figsize=(width, height_subplot * nrow))
        axs = num.atleast_2d(axs)

        for i in range(nsubplots):
            coli, rowi = utility.mod_i(i, nrow)
            ax = axs[rowi, coli]

            if var_idx > n - 1:
                try:
                    fig.delaxes(ax)
                except KeyError:
                    pass
                continue

            if nvar == 1:
                # one panel per patch; i restarts at 0 on every page, so the
                # global panel index has to be used
                source_idxs = [backup_source_idxs[var_idx]]

            v = varnames[var_idx]
            var_idx += 1

            color = copy.deepcopy(input_color)

            for d in trace.get_values(
                v, combine=combined, chains=chains, squeeze=False
            ):
                plot_name, transform = get_transform(v)
                d = transform(d)
                # iterate over columns in case varsize > 1

                if v in dist_vars:
                    if source_idxs is None:
                        # at least 1, otherwise num.arange fails for < 6 patches
                        source_idx_step = max(int(num.floor(d.shape[1] / 6)), 1)
                        logger.info(
                            "No patches defined using 1 every %i!", source_idx_step
                        )
                        source_idxs = num.arange(
                            0, d.shape[1], source_idx_step
                        ).tolist()

                    logger.info(
                        "Plotting patches: %s" % utility.list2string(source_idxs)
                    )

                    selected = []
                    for s_idx in source_idxs:
                        try:
                            if isinstance(s_idx, slice):
                                d_sel = num.atleast_2d(d.T[s_idx].mean(0))
                            else:
                                d_sel = num.atleast_2d(d.T[s_idx])
                        except IndexError as err:
                            raise IndexError(
                                "One or several patches do not exist! "
                                "Patch idxs: %s" % utility.list2string([s_idx])
                            ) from err
                        selected.append(d_sel)

                    selected = num.vstack(selected)
                else:
                    selected = num.atleast_2d(d.T)

                nsources = selected.shape[0]
                logger.debug("Number of sources: %i" % nsources)
                for isource, e in enumerate(selected):
                    e = num.atleast_2d(e)
                    if make_bins_flag:
                        varbin = make_bins(e, nbins=nbins, qlist=qlist)
                        varbins.append(varbin)
                    else:
                        varbin = varbins[bin_idx]
                    bin_idx += 1

                    reference = lines[v] if v in lines else None

                    if color is None:
                        pcolor = "black" if nsources == 1 else mpl_graph_color(isource)
                    else:
                        pcolor = color

                    if plot_style == "kde":
                        plot_density(
                            e,
                            shade=alpha,
                            ax=ax,
                            colors=[pcolor],
                            backend="matplotlib",
                            backend_kwargs={
                                "linewidth": 1.0,
                            },
                        )
                        ax.relim()
                        ax.autoscale(tight=False)
                        ax.set_ylim(0)
                        ax.get_xaxis().set_major_locator(MaxNLocator(nbins=5))

                    elif plot_style in ["pdf", "cdf"]:
                        kwargs["label"] = source_idxs
                        # following determine quantile annotations in cdf
                        kwargs["nsources"] = nsources
                        kwargs["isource"] = isource
                        kwargs["cumulative"] = plot_style == "cdf"

                        histplot_op(
                            ax,
                            e,
                            reference=reference,
                            bins=varbin,
                            alpha=alpha,
                            color=pcolor,
                            qlist=qlist,
                            kwargs=kwargs,
                        )
                    else:
                        raise NotImplementedError(
                            'Plot style "%s" not implemented' % plot_style
                        )

                    plot_unit = defaults[hypername(plot_name)].unit
                    try:
                        param = prior_bounds[v]
                    except KeyError:
                        try:
                            title = "{} {}".format(plot_name, float(lines[v]))
                        except (KeyError, TypeError):
                            # no reference, or several reference values
                            title = "{} {}".format(plot_name, plot_unit)
                    else:
                        if v in dist_vars:
                            title = "{} {}".format(v, plot_unit)
                        else:
                            # precision passed locally instead of changing the
                            # global numpy print options
                            lower = num.array2string(
                                param.lower, separator=",", precision=3
                            )[1:-1]
                            upper = num.array2string(
                                param.upper, separator=",", precision=3
                            )[1:-1]
                            title = "{} {} \npriors: ({}; {})".format(
                                plot_name,
                                plot_unit,
                                lower,
                                upper,
                            )

                    ax.set_xlabel(title, fontsize=fontsize)
                    if nvar == 1:
                        ax.set_title(
                            "Patch %s" % utility.list2string(source_idxs),
                            loc="left",
                            fontsize=fontsize,
                        )
                    ax.grid(grid)
                    ax.get_yaxis().set_visible(False)
                    format_axes(ax)
                    ax.tick_params(axis="x", labelsize=fontsize)

                    if v in lines:
                        # ravel: accept a scalar as well as several values
                        for line in num.ravel(lines[v]):
                            ax.axvline(x=line, color="white", lw=1.0)
                            ax.axvline(
                                x=line,
                                color="black",
                                linestyle="dashed",
                                lw=1.0,
                            )

                    if show_posterior:
                        if posterior == "all":
                            for k, idx in posterior_idxs.items():
                                ax.axvline(x=e[:, idx], color=colors[k], lw=1.0)
                        else:
                            idx = posterior_idxs[posterior]
                            ax.axvline(x=e[:, idx], color=pcolor, lw=1.0)

        if unify:
            page_varnames = varnames[varname_page_idx : varname_page_idx + nsubplots]
            unities = unify_tick_intervals(
                axs, page_varnames, ntickmarks_max=ntickmarks_max, axis="x"
            )
            apply_unified_axis(
                axs, page_varnames, unities, axis="x", scale_factor=scale_factor
            )
        varname_page_idx += nsubplots

        fig.subplots_adjust(wspace=0.05, hspace=0.5)
        fig.tight_layout()
        figs.append(fig)
        fig_axs.append(axs)

    return figs, fig_axs, varbins
def correlation_plot(
    mtrace,
    varnames=None,
    figsize=None,
    cmap=None,
    grid=200,
    point=None,
    point_style=".",
    point_color="white",
    point_size="8",
):
    """
    Plot 2d marginals (with kernel density estimation) showing the
    correlations of the model parameters.

    Parameters
    ----------
    mtrace : :class:`.base.MutliTrace`
        Mutlitrace instance containing the sampling results
    varnames : list of variable names
        Variables to be plotted, if None all variable are plotted
    figsize : figure size tuple
        If None, an A4 landscape page
    cmap : matplotlib colormap
    grid : resolution of kernel density estimation
    point : dict
        Dictionary of variable name / value  to be overplotted as marker
        to the posteriors e.g. mean of posteriors, true values of a simulation
    point_style : str
        style of marker according to matplotlib conventions
    point_color : str or tuple of 3
        color according to matplotlib convention
    point_size : str
        marker size according to matplotlib conventions

    Returns
    -------
    fig : figure object
    axs : 2d array of subplot axis handles, (nvar - 1, nvar - 1)

    Raises
    ------
    ValueError
        If fewer than two variables are given or a variable has more than
        one element.
    """
    if varnames is None:
        varnames = mtrace.varnames

    nvar = len(varnames)
    if nvar < 2:
        raise ValueError(
            "Correlation plot needs at least two variables! Found %i." % nvar
        )

    if figsize is None:
        figsize = mpl_papersize("a4", "landscape")

    # squeeze=False keeps axs 2-D also for a single panel (nvar == 2)
    fig, axs = plt.subplots(
        sharey="row",
        sharex="col",
        nrows=nvar - 1,
        ncols=nvar - 1,
        figsize=figsize,
        squeeze=False,
    )

    d = dict()
    for var in varnames:
        _, transform = get_transform(var)
        vals = transform(mtrace.get_values(var, combine=True, squeeze=True))
        if vals.ndim == 1:
            # squeeze=True drops the element axis of size-1 variables
            vals = num.atleast_2d(vals).T

        _, nvar_elements = vals.shape
        if nvar_elements > 1:
            raise ValueError(
                "Correlation plot can only be displayed for variables "
                " with size 1! %s is %i! " % (var, nvar_elements)
            )

        d[var] = vals

    for i_k in range(nvar - 1):
        varname_a = varnames[i_k]
        a = d[varname_a]
        for i_l in range(i_k + 1, nvar):
            ax = axs[i_l - 1, i_k]
            varname_b = varnames[i_l]
            logger.debug("%s, %s" % (varname_a, varname_b))
            b = d[varname_b]

            kde2plot(a, b, grid=grid, ax=ax, cmap=cmap, aspect="auto")

            if point is not None:
                ax.plot(
                    point[varname_a],
                    point[varname_b],
                    color=point_color,
                    marker=point_style,
                    markersize=point_size,
                )

            ax.tick_params(direction="in")

            if i_k == 0:
                ax.set_ylabel(varname_b)

        # bottom row of the lower triangle
        axs[nvar - 2, i_k].set_xlabel(varname_a)

    # remove the unused upper triangle
    for i_k in range(nvar - 1):
        for i_l in range(i_k):
            fig.delaxes(axs[i_l, i_k])

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.05, hspace=0.05)
    return fig, axs
def correlation_plot_hist(
    mtrace,
    varnames=None,
    figsize=None,
    hist_color=None,
    cmap=None,
    grid=50,
    chains=None,
    ntickmarks=2,
    point=None,
    point_style=".",
    point_color="red",
    point_size=4,
    alpha=0.35,
    unify=True,
):
    """
    Plot 2d marginals (with kernel density estimation) showing the
    correlations of the model parameters. In the main diagonal is shown
    the parameter histograms.

    Parameters
    ----------
    mtrace : :class:`pymc.backends.base.MultiTrace`
        Mutlitrace instance containing the sampling results
    varnames : list of variable names
        Variables to be plotted, if None all variable are plotted
    figsize : figure size tuple
        If None, A5 landscape for fewer than 5 variables, else A4 landscape
    cmap : matplotlib colormap
    hist_color : str or tuple of 3
        color according to matplotlib convention
    grid : resolution of kernel density estimation
    chains : int or list of ints
        chain indexes to select from the trace
    ntickmarks : int
        number of ticks at the axis labels
    point : dict
        Dictionary of variable name / value  to be overplotted as marker
        to the posteriors e.g. mean of posteriors, true values of a simulation
    point_style : str
        style of marker according to matplotlib conventions
    point_color : str or tuple of 3
        color according to matplotlib convention
    point_size : str
        marker size according to matplotlib conventions
    alpha : float
        alpha of the histograms
    unify: bool
        If true axis units that belong to one group e.g. [km] will
        have common axis increments

    Returns
    -------
    figs : list of figure objects, one per source element
    axes : list of 2d arrays of subplot axis handles

    Raises
    ------
    ValueError
        If no variables are given.
    """
    fontsize = 9
    ntickmarks_max = 2
    label_pad = 25
    logger.info("Drawing correlation figure ...")
    logger.warning("Does NOT seperate parameters correctly for Mixed Type Setups!")

    if varnames is None:
        varnames = mtrace.varnames

    nvar = len(varnames)
    if nvar == 0:
        raise ValueError("No variables given for the correlation plot!")

    if figsize is None:
        if nvar < 5:
            figsize = mpl_papersize("a5", "landscape")
        else:
            figsize = mpl_papersize("a4", "landscape")

    d = dict()
    for var in varnames:
        _, transform = get_transform(var)
        vals = transform(
            mtrace.get_values(var, chains=chains, combine=True, squeeze=True)
        )
        logger.info("Getting data for `%s` from sampled trace." % var)
        try:
            _, nvar_elements = vals.shape
        except ValueError:  # for variables with dim=1
            nvar_elements = 1
            vals = num.atleast_2d(vals).T

        d[var] = vals

    figs = []
    axes = []
    for source_i in range(nvar_elements):
        logger.info("for variables of source %i ..." % source_i)
        hist_ylims = []
        # squeeze=False keeps axs 2-D also for a single variable
        fig, axs = plt.subplots(nrows=nvar, ncols=nvar, figsize=figsize, squeeze=False)

        if hist_color is None:
            if nvar_elements == 1:
                pcolor = "orange"
            else:
                pcolor = mpl_graph_color(source_i)
        else:
            pcolor = hist_color

        for i_k in range(nvar):
            v_namea = varnames[i_k]
            a = d[v_namea][:, source_i]

            for i_l in range(i_k, nvar):
                ax = axs[i_l, i_k]
                v_nameb = varnames[i_l]
                plot_name_a, transform_a = get_transform(v_namea)
                plot_name_b, transform_b = get_transform(v_nameb)
                logger.debug("%s, %s" % (v_namea, v_nameb))
                if i_l == i_k:
                    # diagonal: histogram of the variable
                    reference = None
                    if point is not None and v_namea in point:
                        reference = transform_a(point[v_namea][source_i])
                        ax.axvline(
                            x=reference, color=point_color, lw=point_size / 4.0
                        )

                    histplot_op(
                        ax,
                        num.atleast_2d(a),
                        alpha=alpha,
                        color=pcolor,
                        tstd=0.0,
                        reference=reference,
                    )

                    ax.get_yaxis().set_visible(False)
                    format_axes(ax)
                    # reused by the 2d panels below in the same column
                    xticks = ax.get_xticks()
                    xlim = ax.get_xlim()
                    hist_ylims.append(ax.get_ylim())
                else:
                    b = d[v_nameb][:, source_i]

                    kde2plot(a, b, grid=grid, ax=ax, cmap=cmap, aspect="auto")

                    # both names must be present; the former
                    # ``v_namea and v_nameb in point`` only tested v_nameb
                    if point is not None and v_namea in point and v_nameb in point:
                        value_vara = transform_a(point[v_namea][source_i])
                        value_varb = transform_b(point[v_nameb][source_i])
                        ax.plot(
                            value_vara,
                            value_varb,
                            color=point_color,
                            marker=point_style,
                            markersize=point_size,
                        )

                    yticker = MaxNLocator(nbins=ntickmarks)
                    ax.set_xticks(xticks)
                    ax.set_xlim(xlim)
                    yax = ax.get_yaxis()
                    yax.set_major_locator(yticker)

                if i_l != nvar - 1:
                    ax.get_xaxis().set_ticklabels([])

                if i_k == 0:
                    ax.set_ylabel(
                        plot_name_b + "\n " + defaults[hypername(plot_name_b)].unit,
                        fontsize=fontsize,
                    )
                    if utility.is_odd(i_l):
                        ax.tick_params(axis="y", pad=label_pad)
                else:
                    ax.get_yaxis().set_ticklabels([])

                ax.tick_params(axis="both", direction="in", labelsize=fontsize)

                try:  # matplotlib version issue workaround
                    ax.tick_params(axis="both", labelrotation=50.0)
                except Exception:
                    ax.set_xticklabels(axs[i_l, i_k].get_xticklabels(), rotation=50)
                    ax.set_yticklabels(axs[i_l, i_k].get_yticklabels(), rotation=50)

                if utility.is_odd(i_k):
                    ax.tick_params(axis="x", pad=label_pad)

            # ax is the bottom panel of column i_k here
            ax.set_xlabel(
                plot_name_a + "\n " + defaults[hypername(plot_name_a)].unit,
                fontsize=fontsize,
            )

        if unify:
            # column-major pairing of axes and names (see ravel("F"))
            varnames_repeat_x = [
                var_reap for varname in varnames for var_reap in (varname,) * nvar
            ]
            varnames_repeat_y = varnames * nvar
            unitiesx = unify_tick_intervals(
                axs, varnames_repeat_x, ntickmarks_max=ntickmarks_max, axis="x"
            )

            apply_unified_axis(
                axs,
                varnames_repeat_x,
                unitiesx,
                axis="x",
                scale_factor=1.0,
                ntickmarks_max=ntickmarks_max,
            )
            apply_unified_axis(
                axs,
                varnames_repeat_y,
                unitiesx,
                axis="y",
                scale_factor=1.0,
                ntickmarks_max=ntickmarks_max,
            )

        for i_k in range(nvar):
            if unify:
                # reset histogram ylims after unify
                axs[i_k, i_k].set_ylim(hist_ylims[i_k])

            # remove the unused upper triangle
            for i_l in range(i_k):
                fig.delaxes(axs[i_l, i_k])

        fig.tight_layout()
        fig.subplots_adjust(wspace=0.05, hspace=0.05)
        figs.append(fig)
        axes.append(axs)

    return figs, axes
def draw_posteriors(problem, plot_options):
    """
    Plot the posterior marginals for every stage selected by
    ``plot_options.load_stage`` and save the figures.

    Stages whose output figure already exists (and ``force`` is not set)
    are skipped; the remaining stages are still plotted.

    Parameters
    ----------
    problem : problem instance providing config, output folder and names
    plot_options : plot options (plot_projection, load_stage, varnames,
        source_idxs, reference, post_llk, figure_dir, outformat, force, dpi)

    Raises
    ------
    ValueError
        If ``plot_options.plot_projection`` is not supported.
    """
    plot_style_choices = ["pdf", "cdf", "kde", "local"]

    hypers = utility.check_hyper_flag(problem)
    po = plot_options

    if po.plot_projection not in plot_style_choices:
        raise ValueError(
            "Supported plot-projections are: %s"
            % utility.list2string(plot_style_choices)
        )

    if po.plot_projection == "local":
        plot_style = "pdf"
        nbins = 40
    else:
        plot_style = po.plot_projection
        nbins = 200

    logger.info('Plotting "%s"' % plot_style)

    stage = Stage(
        homepath=problem.outfolder, backend=problem.config.sampler_config.backend
    )

    pc = problem.config.problem_config
    list_indexes = stage.handler.get_stage_indexes(po.load_stage)

    if hypers:
        varnames = problem.hypernames + ["like"]
    else:
        varnames = (
            problem.varnames
            + problem.hypernames
            + problem.hierarchicalnames
            + ["like"]
        )

    if len(po.varnames) > 0:
        varnames = po.varnames

    logger.info("Plotting variables: %s" % (", ".join(varnames)))

    # identical for every stage
    prior_bounds = {}
    prior_bounds.update(**pc.hyperparameters)
    prior_bounds.update(**pc.hierarchicals)
    prior_bounds.update(**pc.priors)

    if po.source_idxs:
        sidxs = utility.list2string(po.source_idxs, fill="_")
    else:
        sidxs = ""

    for s in list_indexes:
        outpath = os.path.join(
            problem.outfolder,
            po.figure_dir,
            "stage_%i_%s_%s_%s" % (s, sidxs, po.post_llk, plot_style),
        )

        if plot_exists(outpath, po.outformat, po.force):
            # only skip this stage, remaining stages may still be missing
            continue

        logger.info("plotting stage: %s" % stage.handler.stage_path(s))
        stage.load_results(
            varnames=problem.varnames,
            model=problem.model,
            stage_number=s,
            load="trace",
            chains=[-1],
        )

        figs, _, _ = traceplot(
            stage.mtrace,
            # copy: traceplot may remove entries from the list it gets
            varnames=list(varnames),
            chains=None,
            combined=True,
            source_idxs=po.source_idxs,
            plot_style=plot_style,
            lines=po.reference,
            posterior=po.post_llk,
            prior_bounds=prior_bounds,
            nbins=nbins,
        )

        save_figs(figs, outpath, po.outformat, po.dpi)
def draw_correlation_hist(problem, plot_options):
    """
    Draw parameter correlation plot and histograms for a model result
    ensemble. Only feasible for 'geometry' and 'bem' problems.

    Parameters
    ----------
    problem : problem instance providing config, output folder and names
    plot_options : plot options (load_stage, varnames, reference, post_llk,
        figure_dir, outformat, force, dpi)

    Raises
    ------
    NotImplementedError
        If the problem mode is neither geometry nor bem.
    ValueError
        If ``plot_options.load_stage`` is 0.
    TypeError
        If fewer than two variables are selected.
    """
    po = plot_options
    mode = problem.config.problem_config.mode

    if mode not in [geometry_mode_str, bem_mode_str]:
        raise NotImplementedError(f"The correlation plot is not implemented for {mode}")

    # explicit check instead of ``assert`` (which is stripped under -O)
    if po.load_stage == 0:
        raise ValueError("The correlation plot is not available for load_stage 0!")

    hypers = utility.check_hyper_flag(problem)

    if hypers:
        varnames = problem.hypernames
    else:
        varnames = list(problem.varnames)

    if len(po.varnames) > 0:
        varnames = po.varnames

    logger.info("Plotting variables: %s" % (", ".join(varnames)))

    if len(varnames) < 2:
        raise TypeError(
            "Need at least two parameters to compare! "
            "Found only %i variables! " % len(varnames)
        )

    stage = load_stage(problem, stage_number=po.load_stage, load="trace", chains=[-1])

    if not po.reference:
        reference = get_result_point(stage.mtrace, po.post_llk)
        llk_str = po.post_llk
    else:
        reference = po.reference
        llk_str = "ref"

    outpath = os.path.join(
        problem.outfolder,
        po.figure_dir,
        "corr_hist_%s_%s" % (stage.number, llk_str),
    )

    if plot_exists(outpath, po.outformat, po.force):
        return

    figs, _ = correlation_plot_hist(
        mtrace=stage.mtrace,
        varnames=varnames,
        cmap=plt.cm.gist_earth_r,
        chains=None,
        point=reference,
        point_size=6,
        point_color="red",
    )

    save_figs(figs, outpath, po.outformat, po.dpi)