import copy
import logging
import os
import numpy as num
from arviz import plot_density
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator
from pyrocko.cake_plot import str_to_mpl_color as scolor
from pyrocko.plot import AutoScaler, mpl_graph_color, mpl_papersize, nice_value
from beat import utility
from beat.config import bem_mode_str, dist_vars, geometry_mode_str
from beat.defaults import hypername
from beat.heart import defaults
from beat.models import Stage, load_stage
from .common import (
format_axes,
get_result_point,
get_transform,
histplot_op,
kde2plot,
plot_exists,
save_figs,
)
logger = logging.getLogger("plotting.marginals")
[docs]
def unify_tick_intervals(axs, varnames, ntickmarks_max=5, axis="x"):
    """
    Determine the smallest and largest axis extent per unit class.

    Each axes object is paired with a variable name (axes are flattened in
    column-major order). For every unit class in ``utility.unit_sets`` the
    minimum and maximum displayed extent among its member variables is
    recorded. If the largest extent divided by ``ntickmarks_max`` still
    exceeds the smallest extent, the smallest extent is overwritten by that
    fraction, which assures that the number of increments is not larger than
    ``ntickmarks_max``.

    Parameters
    ----------
    axs : array of matplotlib axes
        must support ``ravel("F")``
    varnames : list of str
        variable name belonging to each of the axes
    ntickmarks_max : int
        divisor applied to the largest extent of each unit class
    axis : str
        "x" or "y", selects which axis limits are inspected

    Returns
    -------
    dict : with unit_sets keys and [min_range, max_range] as values
        Unit classes without any matching variable keep [inf, -inf].

    Raises
    ------
    ValueError
        if axis limits are requested for an axis other than "x" or "y"
    """
    limit_getter_names = {"x": "get_xlim", "y": "get_ylim"}
    unities = {setname: [num.inf, -num.inf] for setname in utility.unit_sets}

    def update_type_range(ax, varname):
        # only existing keys are re-assigned, so iterating the dict is safe
        for setname in unities:
            if axis not in limit_getter_names:
                raise ValueError('Only "x" or "y" allowed!')

            extent = num.diff(getattr(ax, limit_getter_names[axis])())
            if varname not in utility.unit_sets[setname]:
                continue

            smallest, largest = unities[setname]
            unities[setname] = [
                extent if extent < smallest else smallest,
                extent if extent > largest else largest,
            ]

    for ax, varname in zip(axs.ravel("F"), varnames):
        update_type_range(ax, varname)

    for setname, (smallest, largest) in unities.items():
        largest_fraction = largest / ntickmarks_max
        if largest_fraction > smallest:
            logger.debug(
                "Range difference between min and max for %s is large!"
                " Extending min_range to %f" % (setname, largest_fraction)
            )
            unities[setname] = [largest_fraction, largest]

    return unities
def apply_unified_axis(
    axs, varnames, unities, axis="x", ntickmarks_max=3, scale_factor=2 / 3
):
    """
    Apply common tick increments to axes of variables sharing a unit class.

    Axes are flattened in column-major order and paired with ``varnames``.
    For variables listed in ``utility.grouped_vars`` the tick increment is
    derived from the minimum range of the matching unit class, the axis
    limits are snapped to that increment, truncated at the physical bounds
    of the variable and explicit ticks are set. All other variables get a
    ``MaxNLocator`` with 3 bins.

    Parameters
    ----------
    axs : array of matplotlib axes
        must support ``size`` and ``ravel("F")``
    varnames : list of str
        variable name belonging to each of the axes
    unities : dict
        unit class name to [min_range, max_range], as returned by
        :func:`unify_tick_intervals`
    axis : str
        "x" or "y", axis to be modified
    ntickmarks_max : int
        approximate number of ticks handed to the AutoScaler
    scale_factor : float
        factor applied to the minimum range before rounding it to a nice
        increment

    Raises
    ------
    ValueError
        if axis is neither "x" nor "y"
    """
    # fail early and consistently with unify_tick_intervals instead of
    # running into an unbound local (grouped variables) or silently doing
    # nothing (other variables)
    if axis not in ("x", "y"):
        raise ValueError('Only "x" or "y" allowed!')

    use_x = axis == "x"

    naxs = axs.size
    nvars = len(varnames)
    if naxs != nvars:
        logger.warning(
            "Inconsistenet number of Axes: %i and variables: %i!" % (naxs, nvars)
        )

    for ax, v in zip(axs.ravel("F"), varnames):
        if v in utility.grouped_vars:
            for setname, varrange in unities.items():
                if v not in utility.unit_sets[setname]:
                    continue

                inc = nice_value(varrange[0] * scale_factor)
                autos = AutoScaler(inc=inc, snap="on", approx_ticks=ntickmarks_max)

                lower, upper = ax.get_xlim() if use_x else ax.get_ylim()
                lower, upper, _ = autos.make_scale(
                    (lower, upper), override_mode="min-max"
                )

                # check physical bounds if passed truncate
                phys_min, phys_max = defaults[v].physical_bounds
                if lower < phys_min:
                    lower = phys_min
                if upper > phys_max:
                    upper = phys_max

                ticks = num.arange(lower, upper + inc, inc).tolist()
                if use_x:
                    ax.set_xlim((lower, upper))
                    ax.xaxis.set_ticks(ticks)
                else:
                    ax.set_ylim((lower, upper))
                    ax.yaxis.set_ticks(ticks)
        else:
            ticker = MaxNLocator(nbins=3)
            if use_x:
                ax.get_xaxis().set_major_locator(ticker)
            else:
                ax.get_yaxis().set_major_locator(ticker)
[docs]
def traceplot(
    trace,
    varnames=None,
    lines=None,
    chains=None,
    combined=False,
    grid=False,
    varbins=None,
    nbins=40,
    color=None,
    source_idxs=None,
    alpha=0.35,
    priors=None,
    prior_alpha=1,
    prior_style="--",
    posterior=None,
    plot_style="kde",
    prior_bounds=None,
    unify=True,
    qlist=[0.1, 99.9],
    kwargs=None,
):
    """
    Plots posterior pdfs as histograms from multiple mtrace objects.

    Modified from pymc3.

    Parameters
    ----------
    trace : result of MCMC run
    varnames : list of variable names
        Variables to be plotted, if None all variable are plotted.
        The list handed in is not modified.
    posterior : str
        To mark posterior value in distribution 'max', 'min', 'mean', 'all'.
        None or 'None' disables the markers.
    lines : dict
        Dictionary of variable name / value to be overplotted as vertical
        lines to the posteriors and horizontal lines on sample values
        e.g. mean of posteriors, true values of a simulation
    chains : int or list of ints
        chain indexes to select from the trace
    combined : bool
        Flag for combining multiple chains into a single chain. If False
        (default), chains will be plotted separately.
    source_idxs : list
        array like, indexes to sources to plot marginals
    grid : bool
        Flag for adding gridlines to histogram. Defaults to False.
    varbins : list of arrays
        List containing the binning arrays for the variables, if None they will
        be created.
    nbins : int
        Number of bins for each histogram
    color : tuple
        mpl color tuple
    alpha : float
        Alpha value for plot line. Defaults to 0.35.
    priors, prior_alpha, prior_style :
        accepted for backwards compatibility, not used by this function
    plot_style : str
        'kde', 'pdf' or 'cdf'
    prior_bounds : dict
        variable name to parameter object with ``lower`` and ``upper``
        bounds, used to annotate the axis labels
    unify : bool
        If true axis units that belong to one group e.g. [km] will
        have common axis increments
    kwargs : dict
        for histplot op. The keys 'ntickmarks_max' and 'scale_factor' are
        consumed here. The dict handed in is not modified.
    qlist : list
        of percentiles that bound the bins. Default: [0.1, 99.9].
        If None the data minimum and maximum are used for the bins.

    Returns
    -------
    figs : list of matplotlib figures
    fig_axs : list of arrays of matplotlib axes, one per figure
    varbins : list of bin arrays

    Raises
    ------
    IOError
        if a single variable is selected without source_idxs
    IndexError
        if a requested patch does not exist
    NotImplementedError
        for unknown plot_style
    """
    fontsize = 10

    lines = {} if lines is None else lines
    prior_bounds = {} if prior_bounds is None else prior_bounds
    # work on a private copy: options are popped and added below, which must
    # neither leak into later calls nor alter the dict of the caller
    kwargs = {} if kwargs is None else dict(kwargs)

    ntickmarks_max = kwargs.pop("ntickmarks_max", 3)
    scale_factor = kwargs.pop("scale_factor", 2 / 3)

    # the python default None as well as the string "None" disable markers
    show_posterior = posterior not in (None, "None")

    num.set_printoptions(precision=3)

    def make_bins(data, nbins=40, qlist=None):
        d = data.ravel()
        if qlist is not None:
            qu = num.percentile(d, q=qlist)
            mind, maxd = qu[0], qu[-1]
        else:
            mind = d.min()
            maxd = d.max()
        return num.linspace(mind, maxd, nbins)

    if varnames is None:
        varnames = [name for name in trace.varnames if not name.endswith("_")]
    else:
        # copy, likelihood names are removed below
        varnames = list(varnames)

    for likelihood_name in ("geo_like", "seis_like"):
        if likelihood_name in varnames:
            varnames.remove(likelihood_name)

    if show_posterior:
        llk = trace.get_values("like", combine=combined, chains=chains, squeeze=False)
        llk = num.squeeze(llk[0])
        llk = num.atleast_2d(llk)

        posterior_idxs = utility.get_fit_indexes(llk)

        colors = {
            "mean": scolor("orange1"),
            "min": scolor("butter1"),
            "max": scolor("scarletred2"),
        }

    n = nvar = len(varnames)
    if n == 1 and source_idxs is None:
        raise IOError(
            "If only single variable is selected source_idxs need to be specified!"
        )
    elif n == 1 and len(source_idxs) > 1:
        n = len(source_idxs)
        logger.info("Plotting of patches in panels ...")
        varnames = varnames * n
    else:
        logger.info("Plotting variables in panels ...")

    if varbins is None:
        make_bins_flag = True
        varbins = []
    else:
        make_bins_flag = False

    input_color = copy.deepcopy(color)
    backup_source_idxs = copy.deepcopy(source_idxs)

    # subfigure handling
    nrowtotal = int(num.ceil(n / 2.0))
    ncol = 2
    nrow_max = 4
    nplots_page_max = nrow_max * ncol

    n_subplots_total = nrowtotal * ncol
    ntotal_figs, nrest_subplots = utility.mod_i(n_subplots_total, nplots_page_max)
    nsubplots_page = [nplots_page_max for _ in range(ntotal_figs)]
    if nrest_subplots:
        nsubplots_page.append(nrest_subplots)

    figs = []
    fig_axs = []
    var_idx = 0
    varname_page_idx = 0
    for nsubplots in nsubplots_page:
        width, height = mpl_papersize("a4", "portrait")
        height_subplot = height / nrow_max
        nrow = int(num.ceil(nsubplots / ncol))

        fig, axs = plt.subplots(nrow, ncol, figsize=(width, height_subplot * nrow))
        axs = num.atleast_2d(axs)

        for i in range(nsubplots):
            # panels are filled column by column
            coli, rowi = utility.mod_i(i, nrow)
            ax = axs[rowi, coli]

            if var_idx > n - 1:
                try:
                    fig.delaxes(ax)
                except KeyError:
                    pass
            else:
                # index of the panel across all pages; the page-local "i"
                # restarts at zero on every page and must not be used to
                # address per-variable inputs
                panel_idx = var_idx
                if nvar == 1:
                    source_idxs = [backup_source_idxs[panel_idx]]

                v = varnames[var_idx]
                var_idx += 1
                color = copy.deepcopy(input_color)

                for d in trace.get_values(
                    v, combine=combined, chains=chains, squeeze=False
                ):
                    plot_name, transform = get_transform(v)
                    d = transform(d)
                    # iterate over columns in case varsize > 1

                    if v in dist_vars:
                        if source_idxs is None:
                            # at least 1, a zero step is invalid for arange
                            source_idx_step = max(
                                1, int(num.floor(d.shape[1] / 6))
                            )
                            logger.info(
                                "No patches defined using 1 every %i!", source_idx_step
                            )
                            source_idxs = num.arange(
                                0, d.shape[1], source_idx_step
                            ).tolist()

                        logger.info(
                            "Plotting patches: %s" % utility.list2string(source_idxs)
                        )

                        selected = []
                        for s_idx in source_idxs:
                            try:
                                if isinstance(s_idx, slice):
                                    d_sel = num.atleast_2d(d.T[s_idx].mean(0))
                                else:
                                    d_sel = num.atleast_2d(d.T[s_idx])
                            except IndexError:
                                raise IndexError(
                                    "One or several patches do not exist! "
                                    "Patch idxs: %s" % utility.list2string([s_idx])
                                )
                            selected.append(d_sel)

                        selected = num.vstack(selected)
                    else:
                        selected = num.atleast_2d(d.T)

                    nsources = selected.shape[0]
                    logger.debug("Number of sources: %i" % nsources)
                    for isource, e in enumerate(selected):
                        e = num.atleast_2d(e)
                        if make_bins_flag:
                            varbin = make_bins(e, nbins=nbins, qlist=qlist)
                            varbins.append(varbin)
                        else:
                            varbin = varbins[panel_idx]

                        reference = lines[v] if lines and v in lines else None

                        if color is None:
                            if nsources == 1:
                                pcolor = "black"
                            else:
                                pcolor = mpl_graph_color(isource)
                        else:
                            pcolor = color

                        if plot_style == "kde":
                            plot_density(
                                e,
                                shade=alpha,
                                ax=ax,
                                colors=[pcolor],
                                backend="matplotlib",
                                backend_kwargs={
                                    "linewidth": 1.0,
                                },
                            )
                            ax.relim()
                            ax.autoscale(tight=False)
                            ax.set_ylim(0)
                            xax = ax.get_xaxis()
                            xticker = MaxNLocator(nbins=5)
                            xax.set_major_locator(xticker)
                        elif plot_style in ["pdf", "cdf"]:
                            kwargs["label"] = source_idxs
                            # following determine quantile annotations in cdf
                            kwargs["nsources"] = nsources
                            kwargs["isource"] = isource
                            kwargs["cumulative"] = plot_style == "cdf"

                            histplot_op(
                                ax,
                                e,
                                reference=reference,
                                bins=varbin,
                                alpha=alpha,
                                color=pcolor,
                                qlist=qlist,
                                kwargs=kwargs,
                            )
                        else:
                            raise NotImplementedError(
                                'Plot style "%s" not implemented' % plot_style
                            )

                        plot_unit = defaults[hypername(plot_name)].unit
                        try:
                            param = prior_bounds[v]

                            if v in dist_vars:
                                try:  # variable bounds
                                    lower = param.lower[tuple(source_idxs)]
                                    upper = param.upper[tuple(source_idxs)]
                                except IndexError:
                                    lower, upper = param.lower, param.upper

                                title = "{} {}".format(v, plot_unit)
                            else:
                                lower = num.array2string(param.lower, separator=",")[
                                    1:-1
                                ]
                                upper = num.array2string(param.upper, separator=",")[
                                    1:-1
                                ]

                                title = "{} {} \npriors: ({}; {})".format(
                                    plot_name,
                                    plot_unit,
                                    lower,
                                    upper,
                                )
                        except KeyError:
                            try:
                                title = "{} {}".format(plot_name, float(lines[v]))
                            except (KeyError, TypeError):
                                # no reference for v, or the reference is
                                # not convertible to a single float
                                title = "{} {}".format(plot_name, plot_unit)

                        axs[rowi, coli].set_xlabel(title, fontsize=fontsize)
                        if nvar == 1:
                            axs[rowi, coli].set_title(
                                "Patch %s" % utility.list2string(source_idxs),
                                loc="left",
                                fontsize=fontsize,
                            )
                        ax.grid(grid)
                        ax.get_yaxis().set_visible(False)
                        format_axes(axs[rowi, coli])
                        ax.tick_params(axis="x", labelsize=fontsize)

                        if lines:
                            try:
                                for line in lines[v]:
                                    # white underlay keeps the dashed line
                                    # visible on filled histograms
                                    ax.axvline(x=line, color="white", lw=1.0)
                                    ax.axvline(
                                        x=line,
                                        color="black",
                                        linestyle="dashed",
                                        lw=1.0,
                                    )
                            except KeyError:
                                pass

                        if show_posterior:
                            if posterior == "all":
                                for k, idx in posterior_idxs.items():
                                    ax.axvline(x=e[:, idx], color=colors[k], lw=1.0)
                            else:
                                idx = posterior_idxs[posterior]
                                ax.axvline(x=e[:, idx], color=pcolor, lw=1.0)

        if unify:
            page_varnames = varnames[varname_page_idx : varname_page_idx + nsubplots]
            unities = unify_tick_intervals(
                axs, page_varnames, ntickmarks_max=ntickmarks_max, axis="x"
            )
            apply_unified_axis(
                axs, page_varnames, unities, axis="x", scale_factor=scale_factor
            )

        varname_page_idx += nsubplots

        fig.subplots_adjust(wspace=0.05, hspace=0.5)
        fig.tight_layout()
        figs.append(fig)
        fig_axs.append(axs)

    return figs, fig_axs, varbins
[docs]
def correlation_plot(
    mtrace,
    varnames=None,
    figsize=None,
    cmap=None,
    grid=200,
    point=None,
    point_style=".",
    point_color="white",
    point_size="8",
):
    """
    Plot 2d marginals (with kernel density estimation) showing the correlations
    of the model parameters.

    Parameters
    ----------
    mtrace : :class:`.base.MutliTrace`
        Mutlitrace instance containing the sampling results
    varnames : list of variable names
        Variables to be plotted, if None all variable are plotted
    figsize : figure size tuple
        If None, the size of an a4 landscape page is used
    cmap : matplotlib colormap
    grid : resolution of kernel density estimation
    point : dict
        Dictionary of variable name / value to be overplotted as marker
        to the posteriors e.g. mean of posteriors, true values of a simulation
    point_style : str
        style of marker according to matplotlib conventions
    point_color : str or tuple of 3
        color according to matplotlib convention
    point_size : str
        marker size according to matplotlib conventions

    Returns
    -------
    fig : figure object
    axs : subplot axis handles, always a 2d array of shape (nvar-1, nvar-1)

    Raises
    ------
    ValueError
        if fewer than two variables are given or if a variable has more
        than one element
    """
    if varnames is None:
        varnames = mtrace.varnames

    nvar = len(varnames)
    if nvar < 2:
        raise ValueError(
            "Correlation plot needs at least two variables! Got %i." % nvar
        )

    if figsize is None:
        figsize = mpl_papersize("a4", "landscape")

    # squeeze=False keeps a 2d axes array also for two variables (1x1 grid),
    # otherwise the 2d indexing below fails on a bare Axes object
    fig, axs = plt.subplots(
        sharey="row",
        sharex="col",
        nrows=nvar - 1,
        ncols=nvar - 1,
        figsize=figsize,
        squeeze=False,
    )

    d = dict()
    for var in varnames:
        plot_name, transform = get_transform(var)
        vals = transform(mtrace.get_values(var, combine=True, squeeze=True))

        if num.ndim(vals) == 1:
            # variables of dim=1 may come back squeezed, restore the column
            vals = num.atleast_2d(vals).T

        _, nvar_elements = vals.shape

        if nvar_elements > 1:
            raise ValueError(
                "Correlation plot can only be displayed for variables "
                " with size 1! %s is %i! " % (var, nvar_elements)
            )

        d[var] = vals

    for i_k in range(nvar - 1):
        varname_a = varnames[i_k]
        a = d[varname_a]
        for i_l in range(i_k + 1, nvar):
            ax = axs[i_l - 1, i_k]
            varname_b = varnames[i_l]
            logger.debug("%s, %s" % (varname_a, varname_b))
            b = d[varname_b]

            kde2plot(a, b, grid=grid, ax=ax, cmap=cmap, aspect="auto")

            if point is not None:
                ax.plot(
                    point[varnames[i_k]],
                    point[varnames[i_l]],
                    color=point_color,
                    marker=point_style,
                    markersize=point_size,
                )

            ax.tick_params(direction="in")

            if i_k == 0:
                ax.set_ylabel(varname_b)

        # after the inner loop i_l refers to the last variable, i.e. the
        # bottom panel of the column
        axs[i_l - 1, i_k].set_xlabel(varname_a)

    # remove the unused panels above the diagonal
    for i_k in range(nvar - 1):
        for i_l in range(i_k):
            fig.delaxes(axs[i_l, i_k])

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.05, hspace=0.05)
    return fig, axs
[docs]
def correlation_plot_hist(
    mtrace,
    varnames=None,
    figsize=None,
    hist_color=None,
    cmap=None,
    grid=50,
    chains=None,
    ntickmarks=2,
    point=None,
    point_style=".",
    point_color="red",
    point_size=4,
    alpha=0.35,
    unify=True,
):
    """
    Plot 2d marginals (with kernel density estimation) showing the correlations
    of the model parameters. In the main diagonal is shown the parameter
    histograms.

    Parameters
    ----------
    mtrace : :class:`pymc.backends.base.MultiTrace`
        Mutlitrace instance containing the sampling results
    varnames : list of variable names
        Variables to be plotted, if None all variable are plotted
    figsize : figure size tuple
        If None, a5 landscape is used for less than 5 variables,
        a4 landscape otherwise
    cmap : matplotlib colormap
    hist_color : str or tuple of 3
        color according to matplotlib convention
    grid : resolution of kernel density estimation
    chains : int or list of ints
        chain indexes to select from the trace
    ntickmarks : int
        number of ticks at the axis labels
    point : dict
        Dictionary of variable name / value to be overplotted as marker
        to the posteriors e.g. mean of posteriors, true values of a simulation
    point_style : str
        style of marker according to matplotlib conventions
    point_color : str or tuple of 3
        color according to matplotlib convention
    point_size : float
        marker size according to matplotlib conventions, a quarter of it
        is used as line width of the reference line in the histograms
    alpha : float
        transparency handed to the histogram plotting
    unify: bool
        If true axis units that belong to one group e.g. [km] will
        have common axis increments

    Returns
    -------
    figs : list of figure objects, one per source
    axes : list of 2d arrays of subplot axis handles
    """
    fontsize = 9
    ntickmarks_max = 2
    label_pad = 25

    logger.info("Drawing correlation figure ...")
    logger.warning("Does NOT seperate parameters correctly for Mixed Type Setups!")

    if varnames is None:
        varnames = mtrace.varnames

    nvar = len(varnames)

    if figsize is None:
        if nvar < 5:
            figsize = mpl_papersize("a5", "landscape")
        else:
            figsize = mpl_papersize("a4", "landscape")

    d = dict()
    for var in varnames:
        _, transform = get_transform(var)
        vals = transform(
            mtrace.get_values(var, chains=chains, combine=True, squeeze=True)
        )
        logger.info("Getting data for `%s` from sampled trace." % var)

        try:
            _, nvar_elements = vals.shape
        except ValueError:  # for variables with dim=1
            nvar_elements = 1
            vals = num.atleast_2d(vals).T

        d[var] = vals

    figs = []
    axes = []
    # NOTE: the number of sources is taken from the last variable loaded
    for source_i in range(nvar_elements):
        logger.info("for variables of source %i ..." % source_i)
        hist_ylims = []
        # squeeze=False keeps a 2d axes array also for a single variable
        fig, axs = plt.subplots(
            nrows=nvar, ncols=nvar, figsize=figsize, squeeze=False
        )

        if hist_color is None:
            if nvar_elements == 1:
                pcolor = "orange"
            else:
                pcolor = mpl_graph_color(source_i)
        else:
            pcolor = hist_color

        for i_k in range(nvar):
            v_namea = varnames[i_k]
            a = d[v_namea][:, source_i]

            for i_l in range(i_k, nvar):
                ax = axs[i_l, i_k]
                v_nameb = varnames[i_l]
                plot_name_a, transform_a = get_transform(v_namea)
                plot_name_b, transform_b = get_transform(v_nameb)
                logger.debug("%s, %s" % (v_namea, v_nameb))
                if i_l == i_k:
                    # diagonal: histogram of the variable
                    if point is not None and v_namea in point:
                        reference = transform_a(point[v_namea][source_i])
                        ax.axvline(
                            x=reference, color=point_color, lw=point_size / 4.0
                        )
                    else:
                        reference = None

                    histplot_op(
                        ax,
                        num.atleast_2d(a),
                        alpha=alpha,
                        color=pcolor,
                        tstd=0.0,
                        reference=reference,
                    )
                    ax.get_yaxis().set_visible(False)
                    format_axes(ax)
                    # the diagonal panel is drawn first in each column, its
                    # ticks and limits are reused by the panels below
                    xticks = ax.get_xticks()
                    xlim = ax.get_xlim()
                    hist_ylims.append(ax.get_ylim())
                else:
                    b = d[v_nameb][:, source_i]

                    kde2plot(a, b, grid=grid, ax=ax, cmap=cmap, aspect="auto")

                    # both variables need a reference value to draw the
                    # marker; testing only the second one raised a KeyError
                    # when the first was missing
                    if point is not None and v_namea in point and v_nameb in point:
                        value_vara = transform_a(point[v_namea][source_i])
                        value_varb = transform_b(point[v_nameb][source_i])
                        ax.plot(
                            value_vara,
                            value_varb,
                            color=point_color,
                            marker=point_style,
                            markersize=point_size,
                        )

                    yticker = MaxNLocator(nbins=ntickmarks)
                    ax.set_xticks(xticks)
                    ax.set_xlim(xlim)
                    yax = ax.get_yaxis()
                    yax.set_major_locator(yticker)

                if i_l != nvar - 1:
                    ax.get_xaxis().set_ticklabels([])

                if i_k == 0:
                    ax.set_ylabel(
                        plot_name_b + "\n " + defaults[hypername(plot_name_b)].unit,
                        fontsize=fontsize,
                    )
                    if utility.is_odd(i_l):
                        ax.tick_params(axis="y", pad=label_pad)
                else:
                    ax.get_yaxis().set_ticklabels([])

                ax.tick_params(axis="both", direction="in", labelsize=fontsize)

                try:  # matplotlib version issue workaround
                    ax.tick_params(axis="both", labelrotation=50.0)
                except Exception:
                    ax.set_xticklabels(axs[i_l, i_k].get_xticklabels(), rotation=50)
                    ax.set_yticklabels(axs[i_l, i_k].get_yticklabels(), rotation=50)

            # ax is the bottom panel of the column at this point
            if utility.is_odd(i_k):
                ax.tick_params(axis="x", pad=label_pad)

            ax.set_xlabel(
                plot_name_a + "\n " + defaults[hypername(plot_name_a)].unit,
                fontsize=fontsize,
            )

        if unify:
            # variable names in the column-major order of the axes array
            varnames_repeat_x = [
                var_reap for varname in varnames for var_reap in (varname,) * nvar
            ]
            varnames_repeat_y = varnames * nvar
            unitiesx = unify_tick_intervals(
                axs, varnames_repeat_x, ntickmarks_max=ntickmarks_max, axis="x"
            )

            apply_unified_axis(
                axs,
                varnames_repeat_x,
                unitiesx,
                axis="x",
                scale_factor=1.0,
                ntickmarks_max=ntickmarks_max,
            )
            apply_unified_axis(
                axs,
                varnames_repeat_y,
                unitiesx,
                axis="y",
                scale_factor=1.0,
                ntickmarks_max=ntickmarks_max,
            )

        for i_k in range(nvar):
            if unify:
                # reset histogram ylims after unify
                axs[i_k, i_k].set_ylim(hist_ylims[i_k])

            # remove the unused panels above the diagonal
            for i_l in range(i_k):
                fig.delaxes(axs[i_l, i_k])

        fig.tight_layout()
        fig.subplots_adjust(wspace=0.05, hspace=0.05)
        figs.append(fig)
        axes.append(axs)

    return figs, axes
[docs]
def draw_posteriors(problem, plot_options):
    """
    Identify which stage is the last complete stage and plot posteriors.

    One set of figures is written per requested stage. Stages whose output
    already exists are skipped (subject to ``plot_options.force``), the
    remaining stages are still plotted.

    Parameters
    ----------
    problem : problem object
        provides outfolder, config, model and the variable names
    plot_options : plot options object
        provides plot_projection ('pdf', 'cdf', 'kde' or 'local'),
        load_stage, varnames, source_idxs, post_llk, reference, figure_dir,
        outformat, force and dpi

    Raises
    ------
    ValueError
        if plot_options.plot_projection is not a supported plot style
    """
    plot_style_choices = ["pdf", "cdf", "kde", "local"]

    hypers = utility.check_hyper_flag(problem)
    po = plot_options

    if po.plot_projection in plot_style_choices:
        if po.plot_projection == "local":
            plot_style = "pdf"
            nbins = 40
        else:
            plot_style = po.plot_projection
            nbins = 200
    else:
        raise ValueError(
            "Supported plot-projections are: %s"
            % utility.list2string(plot_style_choices)
        )

    logger.info('Plotting "%s"' % plot_style)

    stage = Stage(
        homepath=problem.outfolder, backend=problem.config.sampler_config.backend
    )

    pc = problem.config.problem_config

    list_indexes = stage.handler.get_stage_indexes(po.load_stage)

    if hypers:
        varnames = problem.hypernames + ["like"]
    else:
        varnames = (
            problem.varnames + problem.hypernames + problem.hierarchicalnames + ["like"]
        )

    if len(po.varnames) > 0:
        varnames = po.varnames

    logger.info("Plotting variables: %s" % (", ".join((v for v in varnames))))

    # identical for all stages
    if po.source_idxs:
        sidxs = utility.list2string(po.source_idxs, fill="_")
    else:
        sidxs = ""

    prior_bounds = {}
    prior_bounds.update(**pc.hyperparameters)
    prior_bounds.update(**pc.hierarchicals)
    prior_bounds.update(**pc.priors)

    for s in list_indexes:
        outpath = os.path.join(
            problem.outfolder,
            po.figure_dir,
            "stage_%i_%s_%s_%s" % (s, sidxs, po.post_llk, plot_style),
        )

        if plot_exists(outpath, po.outformat, po.force):
            # only skip this stage, later stages may still need plotting
            continue

        logger.info("plotting stage: %s" % stage.handler.stage_path(s))
        stage.load_results(
            varnames=problem.varnames,
            model=problem.model,
            stage_number=s,
            load="trace",
            chains=[-1],
        )

        figs, _, _ = traceplot(
            stage.mtrace,
            varnames=varnames,
            chains=None,
            combined=True,
            source_idxs=po.source_idxs,
            plot_style=plot_style,
            lines=po.reference,
            posterior=po.post_llk,
            prior_bounds=prior_bounds,
            nbins=nbins,
        )

        save_figs(figs, outpath, po.outformat, po.dpi)
[docs]
def draw_correlation_hist(problem, plot_options):
    """
    Draw parameter correlation plot and histograms for a model result ensemble.

    Only implemented for the geometry and bem problem modes. The figures are
    written below the figure directory of the problem unless they exist
    already.

    Parameters
    ----------
    problem : problem object
        provides config, outfolder and the variable names
    plot_options : plot options object
        provides load_stage (must not be 0), varnames, reference, post_llk,
        figure_dir, outformat, force and dpi

    Raises
    ------
    NotImplementedError
        for other problem modes
    TypeError
        if fewer than two variables are selected
    """
    po = plot_options
    mode = problem.config.problem_config.mode

    supported_modes = (geometry_mode_str, bem_mode_str)
    if mode not in supported_modes:
        raise NotImplementedError(f"The correlation plot is not implemented for {mode}")

    assert po.load_stage != 0

    varnames = (
        problem.hypernames
        if utility.check_hyper_flag(problem)
        else list(problem.varnames)
    )

    # an explicit selection in the plot options takes precedence
    if len(po.varnames) > 0:
        varnames = po.varnames

    logger.info("Plotting variables: %s" % ", ".join(varnames))

    n_selected = len(varnames)
    if n_selected < 2:
        raise TypeError(
            "Need at least two parameters to compare!"
            "Found only %i variables! " % n_selected
        )

    stage = load_stage(problem, stage_number=po.load_stage, load="trace", chains=[-1])

    if po.reference:
        reference = po.reference
        llk_str = "ref"
    else:
        # no reference given, mark the requested posterior point instead
        reference = get_result_point(stage.mtrace, po.post_llk)
        llk_str = po.post_llk

    figure_name = "corr_hist_%s_%s" % (stage.number, llk_str)
    outpath = os.path.join(problem.outfolder, po.figure_dir, figure_name)

    if plot_exists(outpath, po.outformat, po.force):
        return

    figs, _ = correlation_plot_hist(
        mtrace=stage.mtrace,
        varnames=varnames,
        cmap=plt.cm.gist_earth_r,
        chains=None,
        point=reference,
        point_size=6,
        point_color="red",
    )

    save_figs(figs, outpath, po.outformat, po.dpi)