from pymc3 import plots as pmp
from pymc3 import quantiles
from collections import OrderedDict
import math
import os
import logging
import copy
from beat import utility
from beat.models import Stage, load_stage
from beat.models.corrections import StrainRateCorrection
from beat.sampler.metropolis import get_trace_stats
from beat.heart import (init_seismic_targets, init_geodetic_targets,
physical_bounds, StrainRateTensor)
from beat.config import ffi_mode_str, geometry_mode_str, dist_vars
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle, FancyArrow
from matplotlib.collections import PatchCollection
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.ticker as tick
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.stats import kde
import numpy as num
from pyrocko.guts import (Object, String, Dict, List,
Bool, Int, load, StringChoice)
from pyrocko import util, trace
from pyrocko import cake_plot as cp
from pyrocko import orthodrome as otd
from pyrocko.cake_plot import str_to_mpl_color as scolor
from pyrocko.cake_plot import light
from pyrocko.plot import beachball, nice_value, AutoScaler
from pyrocko import gmtpy
import pyrocko.moment_tensor as mt
from pyrocko.plot import mpl_papersize, mpl_init, mpl_graph_color, mpl_margins
# Module wide logger of the plotting functions.
logger = logging.getLogger('plotting')

# Number of meters in a kilometer; used to convert between m and km.
km = 1e3

# Public API of this module.
__all__ = [
    'PlotOptions',
    'correlation_plot',
    'correlation_plot_hist',
    'get_result_point',
    'seismic_fits',
    'scene_fits',
    'traceplot',
    'histplot_op',
]
# LaTeX formatted unit labels used for axis annotations.
# Labels containing a backslash are raw strings, so that e.g. "\c" is not
# parsed as an (invalid) escape sequence; the resulting text is unchanged.
u_nm = '$[Nm]$'
u_km = '$[km]$'
u_km_s = '$[km/s]$'
u_deg = r'$[^{\circ}]$'
u_deg_myr = r'$[^{\circ} / myr]$'
u_m = '$[m]$'
u_v = '$[m^3]$'
u_s = '$[s]$'
u_rad = '$[rad]$'
u_hyp = ''
u_nanostrain = 'nstrain'

# Unit label per variable name; 'h_' is the fallback key for
# hyperparameters (see hypername).
plot_units = {
    'east_shift': u_km,
    'north_shift': u_km,
    'depth': u_km,
    'width': u_km,
    'length': u_km,
    'dip': u_deg,
    'dip1': u_deg,
    'dip2': u_deg,
    'strike': u_deg,
    'strike1': u_deg,
    'strike2': u_deg,
    'rake': u_deg,
    'rake1': u_deg,
    'rake2': u_deg,
    'mix': u_hyp,
    'volume_change': u_v,
    'diameter': u_km,
    'slip': u_m,
    'opening_fraction': u_hyp,
    'azimuth': u_deg,
    'bl_azimuth': u_deg,
    'amplitude': u_nm,
    'bl_amplitude': u_m,
    'locking_depth': u_km,
    'nucleation_dip': u_km,
    'nucleation_strike': u_km,
    'nucleation_x': u_hyp,
    'nucleation_y': u_hyp,
    'time_shift': u_s,
    'uperp': u_m,
    'uparr': u_m,
    'utens': u_m,
    'durations': u_s,
    'velocities': u_km_s,
    'mnn': u_nm,
    'mee': u_nm,
    'mdd': u_nm,
    'mne': u_nm,
    'mnd': u_nm,
    'med': u_nm,
    'magnitude': u_hyp,
    'eps_xx': u_nanostrain,
    'eps_yy': u_nanostrain,
    'eps_xy': u_nanostrain,
    'rotation': u_nanostrain,
    'pole_lat': u_deg,
    'pole_lon': u_deg,
    'omega': u_deg_myr,
    'w': u_rad,
    'v': u_rad,
    'kappa': u_rad,
    'sigma': u_rad,
    'h': u_hyp,
    'distance': u_km,
    'delta_depth': u_km,
    'delta_time': u_s,
    'time': u_s,
    'duration': u_s,
    'peak_ratio': u_hyp,
    'h_': u_hyp,
    'like': u_hyp}

# Available projections for plotting geodetic data.
plot_projections = ['latlon', 'local', 'individual']
def hypername(varname):
    """
    Map a variable name to its key in ``plot_units``.

    Names with a dedicated entry are returned unchanged; every other name
    (e.g. hyperparameters) is mapped to the generic key ``'h_'``.
    """
    # direct dict membership test; no need to materialise the keys as list
    return varname if varname in plot_units else 'h_'
class PlotOptions(Object):
    """
    Options controlling which model is plotted and how figures are stored.
    """
    post_llk = String.T(
        default='max',
        help='Which model to plot on the specified plot; Default: "max";'
             ' Options: "max", "min", "mean", "all"')
    plot_projection = StringChoice.T(
        default='local',
        choices=plot_projections,
        help='Projection to use for plotting geodetic data; options: '
             '"latlon", "local", "individual"')
    # NOTE(review): "utm" is not among the plot_projection choices -
    # confirm whether this option is still used anywhere.
    utm_zone = Int.T(
        default=36,
        optional=True,
        help='Only relevant if plot_projection is "utm"')
    load_stage = Int.T(
        default=-1,
        help='Which stage to select for plotting')
    figure_dir = String.T(
        default='figures',
        help='Name of the output directory of plots')
    reference = Dict.T(
        default={},
        help='Reference point for example from a synthetic test.',
        optional=True)
    outformat = String.T(default='pdf')
    dpi = Int.T(default=300)
    force = Bool.T(default=False)
    varnames = List.T(
        default=[], optional=True, help='Names of variables to plot')
    source_idxs = List.T(
        default=None,
        optional=True,
        help='Indexes to patches of slip distribution to draw marginals for')
    nensemble = Int.T(
        default=1,
        help='Number of draws from the PPD to display fuzzy results.')
def str_unit(quantity):
    """
    Map a waveform quantity name to its LaTeX unit label.

    Raises
    ------
    ValueError
        for quantities other than displacement, velocity, acceleration
    """
    known_units = (
        ('displacement', '$m$'),
        ('velocity', '$m/s$'),
        ('acceleration', '$m/s^2$'),
    )
    for name, unit in known_units:
        if quantity == name:
            return unit

    raise ValueError('Quantity %s not supported!' % quantity)
def str_dist(dist):
    """
    Format a distance in meters as a short human readable string.

    Values below 1 km are printed in meters, larger ones in kilometers;
    the number of decimals shrinks as the magnitude grows.
    """
    if dist < 10.:
        return '%g m' % dist
    if dist < 1. * km:
        return '%.0f m' % dist
    if dist < 10. * km:
        return '%.1f km' % (dist / km)
    return '%.0f km' % (dist / km)
def str_duration(t):
    """
    Convert a duration in seconds to a compact string.

    Negative durations get a leading '-'. Below one minute the value is
    given in seconds, below one hour as 'MM:SS min', below one day as
    'HH:MM h' and otherwise in days.
    """
    prefix = '-' if t < 0. else ''
    t = abs(t)
    one_day = 24. * 3600.

    if t < 60.0:
        return prefix + '%.2g s' % t
    if t < 3600.:
        return prefix + util.time_to_str(t, format='%M:%S min')
    if t < one_day:
        return prefix + util.time_to_str(t, format='%H:%M h')
    return prefix + '%.1f d' % (t / one_day)
def kde2plot_op(ax, x, y, grid=200, **kwargs):
    """
    Draw a 2d gaussian kernel density estimate of samples (x, y) into ax.

    Parameters
    ----------
    ax : matplotlib axes
        target axes; the density is drawn with ``ax.imshow``
    x, y : :class:`numpy.ndarray`
        sample coordinates; flattened before the estimate is computed
    grid : int
        number of grid points per dimension the density is evaluated on
    **kwargs
        passed on to ``ax.imshow``. An ``extent`` of four values
        ``[xmin, xmax, ymin, ymax]`` defines the region the density is
        evaluated on and displayed in; otherwise the sample ranges are used.
    """
    from scipy.stats import gaussian_kde

    extent = kwargs.pop('extent', None)
    if extent is None or len(extent) != 4:
        extent = [x.min(), x.max(), y.min(), y.max()]

    # Evaluate the density on exactly the region imshow displays; otherwise
    # the image would be stretched onto a different extent.
    xmin, xmax, ymin, ymax = extent

    npoints = grid * 1j  # complex step: number of points for mgrid
    X, Y = num.mgrid[xmin:xmax:npoints, ymin:ymax:npoints]
    positions = num.vstack([X.ravel(), Y.ravel()])
    values = num.vstack([x.ravel(), y.ravel()])
    kernel = gaussian_kde(values)
    Z = num.reshape(kernel(positions).T, X.shape)

    # rot90 puts y on the rows (ymax on top) and x on the columns, matching
    # imshow's default origin='upper'
    ax.imshow(num.rot90(Z), extent=extent, **kwargs)
def kde2plot(x, y, grid=200, ax=None, **kwargs):
    """
    Plot a 2d kernel density estimate of (x, y).

    If ``ax`` is None a new figure with a single axes is created. Remaining
    keyword arguments are forwarded to :func:`kde2plot_op`.

    Returns
    -------
    ax : the axes the density was drawn into
    """
    target_ax = ax
    if target_ax is None:
        _, target_ax = plt.subplots(1, 1, squeeze=True)

    kde2plot_op(target_ax, x, y, grid, **kwargs)
    return target_ax
def spherical_kde_op(
        lats0, lons0, lats=None, lons=None, grid_size=(200, 200), sigma=None):
    """
    Spherical kernel density estimate of the points (lats0, lons0) using
    von Mises-Fisher kernels, accumulated in batches of 500 samples.

    Parameters
    ----------
    lats0, lons0 : :class:`numpy.ndarray`
        sample coordinates (presumably degrees, like the default grid)
    lats, lons : :class:`numpy.ndarray`, optional
        evaluation coordinates; if both are None a regular grid spanning
        -90..90 / -180..180 with ``grid_size`` points is created
    grid_size : tuple of 2 ints
        shape of the evaluation grid and of the returned density
    sigma : float, optional
        kernel width; estimated from the samples if None

    Returns
    -------
    kde, lats, lons
        summed kernel values and the evaluation coordinates
    """
    from beat.models.distributions import vonmises_fisher, vonmises_std

    if sigma is None:
        logger.debug('No sigma given, estimating VonMisesStd ...')
        sigmahat = vonmises_std(lats=lats0, lons=lons0)
        # rule of thumb bandwidth of the form 1.06 * std * n ** (-1/5)
        sigma = 1.06 * sigmahat * lats0.size ** -0.2
        logger.debug(
            'suggested sigma: %f, sigmahat: %f' % (sigma, sigmahat))

    if lats is None and lons is None:
        lons, lats = num.meshgrid(
            num.linspace(-180., 180, grid_size[1]),
            num.linspace(-90., 90, grid_size[0]))

    if lats is not None:
        assert lats.size == lons.size

    def _summed_kernels(sub_lats0, sub_lons0):
        # sum of the kernels of the given samples at the evaluation points
        vmf = vonmises_fisher(
            lats=lats, lons=lons,
            lats0=sub_lats0, lons0=sub_lons0, sigma=sigma)
        return num.exp(vmf).sum(axis=-1)

    batch_size = 500
    nbatches, nrest = utility.mod_i(lats0.size, batch_size)

    if nrest == 0:
        logger.debug(
            'Init new spherical kde samples')
        density = num.zeros(grid_size)
    else:
        logger.debug(
            'Processing rest of spherical kde samples %i' % (nrest))
        density = _summed_kernels(
            lats0[-nrest:], lons0[-nrest:]).reshape(
                (grid_size[0], grid_size[1]))

    logger.info('Drawing lune plot for %i samples ... ' % lats0.size)
    for ibatch in range(nbatches):
        start = ibatch * batch_size
        stop = start + batch_size
        logger.debug(
            'Looping over spherical kde samples %i - %i' % (start, stop))
        density += _summed_kernels(lats0[start:stop], lons0[start:stop])

    return density, lats, lons
def correlation_plot(
        mtrace, varnames=None,
        transform=lambda x: x, figsize=None, cmap=None, grid=200, point=None,
        point_style='.', point_color='white', point_size='8'):
    """
    Plot 2d marginals (with kernel density estimation) showing the correlations
    of the model parameters.

    Parameters
    ----------
    mtrace : :class:`pymc3.backends.base.MultiTrace`
        MultiTrace instance containing the sampling results
    varnames : list of variable names
        Variables to be plotted, if None all variable are plotted; at least
        two variables are required
    transform : callable
        Function to transform data (defaults to identity)
    figsize : figure size tuple
        If None, an A4 landscape page is used
    cmap : matplotlib colormap
    grid : resolution of kernel density estimation
    point : dict
        Dictionary of variable name / value to be overplotted as marker
        to the posteriors e.g. mean of posteriors, true values of a simulation
    point_style : str
        style of marker according to matplotlib conventions
    point_color : str or tuple of 3
        color according to matplotlib convention
    point_size : str or float
        marker size according to matplotlib conventions

    Returns
    -------
    fig : figure object
    axs : 2d array of subplot axis handles

    Raises
    ------
    ValueError
        if less than two variables are given or a variable has more than
        one element
    """
    if varnames is None:
        varnames = mtrace.varnames

    nvar = len(varnames)
    if nvar < 2:
        raise ValueError(
            'Correlation plot needs at least two variables, got %i!' % nvar)

    if figsize is None:
        figsize = mpl_papersize('a4', 'landscape')

    # squeeze=False keeps axs 2d also for two variables (a 1x1 grid), so
    # that the axs[row, col] indexing below always works
    fig, axs = plt.subplots(
        sharey='row', sharex='col',
        nrows=nvar - 1, ncols=nvar - 1, figsize=figsize, squeeze=False)

    d = dict()
    for var in varnames:
        vals = transform(
            mtrace.get_values(
                var, combine=True, squeeze=True))

        _, nvar_elements = vals.shape

        if nvar_elements > 1:
            raise ValueError(
                'Correlation plot can only be displayed for variables '
                ' with size 1! %s is %i! ' % (var, nvar_elements))

        d[var] = vals

    # lower triangle: column k holds variable k, row l - 1 variable l
    for k in range(nvar - 1):
        a = d[varnames[k]]
        for l in range(k + 1, nvar):
            logger.debug('%s, %s' % (varnames[k], varnames[l]))
            b = d[varnames[l]]
            ax = axs[l - 1, k]

            kde2plot(
                a, b, grid=grid, ax=ax, cmap=cmap, aspect='auto')

            if point is not None:
                ax.plot(
                    point[varnames[k]], point[varnames[l]],
                    color=point_color, marker=point_style,
                    markersize=point_size)

            ax.tick_params(direction='in')

            if k == 0:
                ax.set_ylabel(varnames[l])

        # x label only on the bottom row of each column
        axs[nvar - 2, k].set_xlabel(varnames[k])

    # remove the unused upper triangle
    for k in range(nvar - 1):
        for l in range(k):
            fig.delaxes(axs[l, k])

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.05, hspace=0.05)
    return fig, axs
def correlation_plot_hist(
        mtrace, varnames=None,
        transform=lambda x: x, figsize=None, hist_color='orange', cmap=None,
        grid=50, chains=None, ntickmarks=2, point=None,
        point_style='.', point_color='red', point_size=4, alpha=0.35,
        unify=True):
    """
    Plot 2d marginals (with kernel density estimation) showing the correlations
    of the model parameters. In the main diagonal is shown the parameter
    histograms. One figure is created per element (e.g. source) of the
    variables.

    Parameters
    ----------
    mtrace : :class:`pymc3.backends.base.MultiTrace`
        MultiTrace instance containing the sampling results
    varnames : list of variable names
        Variables to be plotted, if None all variable are plotted
    transform : callable
        Function to transform data (defaults to identity)
    figsize : figure size tuple
        If None, A5 landscape for less than 5 variables, A4 landscape
        otherwise
    cmap : matplotlib colormap
    hist_color : str or tuple of 3
        color according to matplotlib convention; currently not used, the
        histograms are colored by element (source) index
    grid : resolution of kernel density estimation
    chains : int or list of ints
        chain indexes to select from the trace
    ntickmarks : int
        number of ticks at the axis labels
    point : dict
        Dictionary of variable name / value to be overplotted as marker
        to the posteriors e.g. mean of posteriors, true values of a simulation
    point_style : str
        style of marker according to matplotlib conventions
    point_color : str or tuple of 3
        color according to matplotlib convention
    point_size : float
        marker size according to matplotlib conventions
    alpha : float
        transparency of the histograms
    unify: bool
        If true axis units that belong to one group e.g. [km] will
        have common axis increments

    Returns
    -------
    figs : list of figure objects, one per variable element
    axes : list of 2d arrays of subplot axis handles
    """
    fontsize = 9
    ntickmarks_max = 2
    label_pad = 25

    logger.info('Drawing correlation figure ...')

    if varnames is None:
        varnames = mtrace.varnames

    nvar = len(varnames)

    if figsize is None:
        if nvar < 5:
            figsize = mpl_papersize('a5', 'landscape')
        else:
            figsize = mpl_papersize('a4', 'landscape')

    d = dict()
    for var in varnames:
        vals = transform(
            mtrace.get_values(
                var, chains=chains, combine=True, squeeze=True))

        # NOTE(review): the element count of the last variable is used for
        # all variables; assumes all variables have the same size
        _, nvar_elements = vals.shape
        d[var] = vals

    figs = []
    axes = []

    for source_i in range(nvar_elements):
        logger.info('for variables of source %i ...' % source_i)
        hist_ylims = []
        # squeeze=False keeps axs 2d also for a single variable
        fig, axs = plt.subplots(
            nrows=nvar, ncols=nvar, figsize=figsize, squeeze=False)

        for k in range(nvar):
            v_namea = varnames[k]
            a = d[v_namea][:, source_i]
            pcolor = mpl_graph_color(source_i)

            for l in range(k, nvar):
                v_nameb = varnames[l]
                logger.debug('%s, %s' % (v_namea, v_nameb))
                if l == k:
                    # diagonal: histogram of variable k
                    if point is not None:
                        if v_namea in point.keys():
                            reference = point[v_namea][source_i]
                            axs[l, k].axvline(
                                x=reference, color=point_color,
                                lw=point_size / 4.)
                        else:
                            reference = None
                    else:
                        reference = None

                    histplot_op(
                        axs[l, k], pmp.utils.make_2d(a), alpha=alpha,
                        color=pcolor, tstd=0., reference=reference)

                    axs[l, k].get_yaxis().set_visible(False)
                    format_axes(axs[l, k])
                    # reused for the off-diagonal panels of this column
                    xticks = axs[l, k].get_xticks()
                    xlim = axs[l, k].get_xlim()
                    hist_ylims.append(axs[l, k].get_ylim())
                else:
                    # lower triangle: 2d KDE of variable k vs. variable l
                    b = d[v_nameb][:, source_i]

                    kde2plot(
                        a, b, grid=grid, ax=axs[l, k],
                        cmap=cmap, aspect='auto')

                    if point is not None:
                        # both variables need a reference value
                        if v_namea in point and v_nameb in point:
                            va = point[v_namea][source_i]
                            vb = point[v_nameb][source_i]
                            axs[l, k].plot(
                                va, vb,
                                color=point_color, marker=point_style,
                                markersize=point_size)

                    yticker = tick.MaxNLocator(nbins=ntickmarks)
                    axs[l, k].set_xticks(xticks)
                    axs[l, k].set_xlim(xlim)
                    yax = axs[l, k].get_yaxis()
                    yax.set_major_locator(yticker)

                if l != nvar - 1:
                    axs[l, k].get_xaxis().set_ticklabels([])

                if k == 0:
                    axs[l, k].set_ylabel(
                        v_nameb + '\n ' + plot_units[hypername(v_nameb)],
                        fontsize=fontsize)
                    if utility.is_odd(l):
                        axs[l, k].tick_params(axis='y', pad=label_pad)
                else:
                    axs[l, k].get_yaxis().set_ticklabels([])

                axs[l, k].tick_params(
                    axis='both', direction='in', labelsize=fontsize)

                try:  # matplotlib version issue workaround
                    axs[l, k].tick_params(
                        axis='both', labelrotation=50.)
                except Exception:
                    axs[l, k].set_xticklabels(
                        axs[l, k].get_xticklabels(), rotation=50)
                    axs[l, k].set_yticklabels(
                        axs[l, k].get_yticklabels(), rotation=50)

                if utility.is_odd(k):
                    axs[l, k].tick_params(axis='x', pad=label_pad)

            # x label only on the bottom row of each column
            axs[nvar - 1, k].set_xlabel(
                v_namea + '\n ' + plot_units[hypername(v_namea)],
                fontsize=fontsize)

        if unify:
            varnames_repeat_x = [
                var_reap for varname in varnames
                for var_reap in (varname,) * nvar]
            varnames_repeat_y = varnames * nvar
            unitiesx = unify_tick_intervals(
                axs, varnames_repeat_x, ntickmarks_max=ntickmarks_max,
                axis='x')

            apply_unified_axis(
                axs, varnames_repeat_x, unitiesx, axis='x', scale_factor=1.,
                ntickmarks_max=ntickmarks_max)
            apply_unified_axis(
                axs, varnames_repeat_y, unitiesx, axis='y', scale_factor=1.,
                ntickmarks_max=ntickmarks_max)

        for k in range(nvar):
            if unify:
                # reset histogram ylims after unify
                axs[k, k].set_ylim(hist_ylims[k])

            # remove the unused upper triangle
            for l in range(k):
                fig.delaxes(axs[l, k])

        fig.tight_layout()
        fig.subplots_adjust(wspace=0.05, hspace=0.05)
        figs.append(fig)
        axes.append(axs)

    return figs, axes
def plot(uwifg, point_size=20):
    """
    Quick-look scatter plot of the displacements of an interferogram.

    Parameters
    ----------
    uwifg : object
        needs ``lons``, ``lats``, ``displacement`` and ``name`` attributes
    point_size : int
        size of the scatter plot points
    """
    axis = plt.axes()
    scatter = axis.scatter(
        uwifg.lons, uwifg.lats,
        s=point_size, c=uwifg.displacement,
        edgecolors='none')
    plt.colorbar(scatter)
    plt.title('Displacements [m] %s' % uwifg.name)
    plt.show()
def plot_cov(target, point_size=20):
    """
    Quick-look scatter plot of the column sums of the prediction covariance
    (``target.covariance.pred_v``) at the target locations.
    """
    axis = plt.axes()
    scatter = axis.scatter(
        target.lons, target.lats,
        s=point_size,
        c=num.array(target.covariance.pred_v.sum(axis=0)).flatten(),
        edgecolors='none')
    plt.colorbar(scatter)
    plt.title('Prediction Covariance [m2] %s' % target.name)
    plt.show()
def plot_matrix(A):
    """
    Quick-look image of matrix ``A`` with a colorbar, then show the figure.
    """
    image = plt.axes().matshow(A)
    plt.colorbar(image)
    plt.show()
def plot_log_cov(cov_mat):
    """
    Quick-look image of a (covariance) matrix on a signed log scale.

    Shows log(|cov_mat|) with the sign flipped where cov_mat is negative.
    """
    axis = plt.axes()

    sign = num.ones_like(cov_mat)
    sign[cov_mat < 0] = -1.
    log_magnitude = num.log(num.abs(cov_mat))

    image = axis.imshow(num.multiply(log_magnitude, sign))
    plt.colorbar(image)
    plt.show()
def get_result_point(mtrace, point_llk='max'):
    """
    Return the sample point of a trace selected by its likelihood.

    Parameters
    ----------
    mtrace : trace object
        needs ``get_values`` and ``point`` methods
    point_llk : str or None
        key into the indexes returned by ``utility.get_fit_indexes`` for
        the 'like' values (e.g. 'max'); None or the string 'None' select
        no point

    Returns
    -------
    dict or None
    """
    # accept both the config string 'None' and a real None
    if point_llk is None or point_llk == 'None':
        return None

    llk = mtrace.get_values(
        varname='like',
        combine=True)
    posterior_idxs = utility.get_fit_indexes(llk)
    return mtrace.point(idx=posterior_idxs[point_llk])
def plot_quadtree(ax, data, target, cmap, colim, alpha=0.8):
    """
    Plot UnwrappedIFG displacements on the respective quadtree rectangle.

    Leaf positions and sizes are converted from m to km; colors are
    clipped symmetrically to [-colim, colim] and the axes limits are set
    to the quadtree extent.

    Returns
    -------
    :class:`matplotlib.collections.PatchCollection`
    """
    qt = target.quadtree

    rectangles = [
        Rectangle(
            (east / km, north / km),
            width=size_e / km,
            height=size_n / km,
            edgecolor='black')
        for east, north, size_e, size_n in qt.iter_leaves()]

    patch_col = PatchCollection(
        rectangles, match_original=True, alpha=alpha, linewidth=0.5)
    patch_col.set(array=data, cmap=cmap)
    patch_col.set_clim((-colim, colim))

    east_shifts = qt.east_shifts
    north_shifts = qt.north_shifts
    xlim = (east_shifts.min() / km, (east_shifts + qt.sizeE).max() / km)
    ylim = (north_shifts.min() / km, (north_shifts + qt.sizeN).max() / km)

    ax.add_collection(patch_col)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    return patch_col
def plot_scene(ax, target, data, scattersize, colim,
               outmode='latlon', **kwargs):
    """
    Plot data of a geodetic target into ax.

    Parameters
    ----------
    outmode : str
        'latlon' scatters at target.lons / target.lats; 'local' draws the
        quadtree rectangles if the target has a quadtree, otherwise
        scatters at the east / north shifts in km
    **kwargs
        passed to ``ax.scatter``; for the quadtree only 'cmap' is used

    Returns
    -------
    the created matplotlib collection

    Raises
    ------
    ValueError
        for an unsupported outmode
    """
    if outmode == 'latlon':
        x = target.lons
        y = target.lats
    elif outmode == 'local':
        if target.quadtree is not None:
            cmap = kwargs.pop('cmap', plt.cm.jet)
            return plot_quadtree(ax, data, target, cmap, colim)
        x = target.east_shifts / km
        y = target.north_shifts / km
    else:
        # previously fell through to an UnboundLocalError on x / y
        raise ValueError(
            'Output mode "%s" not supported! Options: "latlon", "local"'
            % outmode)

    return ax.scatter(
        x, y, scattersize, data,
        edgecolors='none', vmin=-colim, vmax=colim, **kwargs)
def format_axes(
        ax, remove=('right', 'top', 'left'), linewidth=None, visible=False):
    """
    Set the visibility (and optionally line width) of axes spines.

    With the defaults the right, top and left spines are hidden.

    Parameters
    ----------
    ax : matplotlib axes
    remove : sequence of str
        names of the spines to modify
    linewidth : float, optional
        line width applied to those spines if given
    visible : bool
        visibility applied to those spines
    """
    # immutable tuple default: never shared-mutated across calls
    for side in remove:
        spine = ax.spines[side]
        spine.set_visible(visible)
        if linewidth is not None:
            spine.set_linewidth(linewidth)
def scale_axes(axis, scale, offset=0.):
    """
    Relabel the major ticks of ``axis`` as ``offset + value * scale``.

    Labels get one decimal and a space as thousands separator; the tick
    positions themselves are not changed.
    """
    from matplotlib.ticker import ScalarFormatter

    class _ScaledFormatter(ScalarFormatter):

        def __call__(self, value, pos):
            label = '{:,.1f}'.format(offset + value * scale)
            return label.replace(',', ' ')

    axis.set_major_formatter(_ScaledFormatter())
def set_anchor(sources, anchor):
    """
    Set the ``anchor`` attribute of every source in place.
    """
    for src in sources:
        setattr(src, 'anchor', anchor)
def get_gmt_colorstring_from_mpl(i):
    """
    Return the i-th matplotlib graph color as GMT 'R/G/B' string.

    The color channels are multiplied by 255 (presumably mapping 0-1
    floats to GMT's 0-255 range).
    """
    rgb = num.array(mpl_graph_color(i)) * 255
    return utility.list2string(rgb.tolist(), '/')
def gnss_fits(problem, stage, plot_options):
    """
    Plot GNSS station offsets of the data (black) and of the synthetics of
    the selected model (red) on GMT maps.

    One map is drawn for the horizontal components (if 'east' or 'north'
    are configured) and one for the vertical component (if 'up' is
    configured). Sources are drawn on top and strain rate corrections, if
    present, as strain rate tensor. For ``plot_options.nensemble > 1`` an
    inset histogram shows the variance reductions [%] of the best model and
    of an ensemble of posterior samples for the plotted component(s).

    Parameters
    ----------
    problem : beat problem object
    stage : :class:`beat.models.Stage`
        stage whose trace is used to select the result point(s)
    plot_options : :class:`PlotOptions`

    Returns
    -------
    figs : list of :class:`pyrocko.automap.Map`

    Raises
    ------
    gmtpy.GmtPyError
        if no GMT installation is found
    ImportError
        if no GNSS data are configured or none could be loaded
    NotImplementedError
        for plot projections other than "latlon" and "local"
    """
    from pyrocko import automap
    from pyrocko.model import gnss

    if len(gmtpy.detect_gmt_installations()) < 1:
        raise gmtpy.GmtPyError(
            'GMT needs to be installed for GNSS plot!')

    gc = problem.config.geodetic_config
    try:
        ds_config = gc.types['GNSS']
    except KeyError:
        raise ImportError('No GNSS data in configuration!')

    logger.info('Trying to load GNSS data from: {}'.format(ds_config.datadir))
    campaigns = ds_config.load_data(campaign=True)

    if not campaigns:
        raise ImportError(
            'Did not find GNSS data under %s' % ds_config.datadir)

    if len(campaigns) > 1:
        logger.warning(
            'Plotting for more than 1 GNSS dataset needs to be implemented')

    campaign = campaigns[0]

    datatype = 'geodetic'
    mode = problem.config.problem_config.mode
    problem.init_hierarchicals()

    figsize = 20.  # size in cm

    po = plot_options

    composite = problem.composites[datatype]
    try:
        sources = composite.sources
        ref_sources = None
    except AttributeError:
        logger.info('FFI gnss fit, using reference source ...')
        ref_sources = composite.config.gf_config.reference_sources
        set_anchor(ref_sources, anchor='top')
        fault = composite.load_fault_geometry()
        sources = fault.get_all_subfaults(
            datatype=datatype, component=composite.slip_varnames[0])
        set_anchor(sources, anchor='top')

    if po.reference:
        if mode != ffi_mode_str:
            composite.point2sources(po.reference)
            ref_sources = copy.deepcopy(composite.sources)
        bpoint = po.reference
    else:
        bpoint = get_result_point(stage.mtrace, po.post_llk)

    results = composite.assemble_results(bpoint)
    bvar_reductions = composite.get_variance_reductions(
        bpoint, weights=composite.weights, results=results)

    dataset_to_result = {}
    for dataset, result in zip(composite.datasets, results):
        if dataset.typ == 'GNSS':
            dataset_to_result[dataset] = result

    if po.plot_projection == 'latlon':
        # map centre: middle of the station bounding box
        locations = campaign.stations
        coords = num.array([loc.effective_latlon for loc in locations])
        lat, lon = num.mean(
            num.vstack([coords.min(0), coords.max(0)]), axis=0)
    elif po.plot_projection == 'local':
        lat, lon = otd.geographic_midpoint_locations(sources)
        coords = num.hstack(
            [source.outline(cs='latlon').T for source in sources]).T
    else:
        raise NotImplementedError(
            '%s projection not implemented!' % po.plot_projection)

    if po.nensemble > 1:
        from tqdm import tqdm
        logger.info(
            'Collecting ensemble of %i '
            'synthetic displacements ...' % po.nensemble)
        nchains = len(stage.mtrace)
        csteps = float(nchains) / po.nensemble
        idxs = num.floor(num.arange(0, nchains, csteps)).astype('int32')
        ens_var_reductions = []
        for idx in tqdm(idxs):
            point = stage.mtrace.point(idx=idx)
            e_results = composite.assemble_results(point)
            ens_var_reductions.append(
                composite.get_variance_reductions(
                    point, weights=composite.weights, results=e_results))

        # per GNSS component: variance reductions [%] of the best model
        # followed by those of the ensemble members
        all_var_reductions = {}
        for dataset in dataset_to_result.keys():
            target_var_reds = [bvar_reductions[dataset.name]]
            target_var_reds.extend(
                var_reds[dataset.name] for var_reds in ens_var_reductions)
            all_var_reductions[dataset.component] = num.array(
                target_var_reds) * 100.

    radius = otd.distance_accurate50m_numpy(
        lat[num.newaxis], lon[num.newaxis],
        coords[:, 0], coords[:, 1]).max()
    radius *= 1.2

    if radius < 30. * km:
        logger.warning(
            'Radius of GNSS campaign %s too small, defaulting'
            ' to 30 km' % campaign.name)
        radius = 30 * km

    model_camp = gnss.GNSSCampaign(
        stations=copy.deepcopy(campaign.stations),
        name='model')

    for dataset, result in dataset_to_result.items():
        for ista, sta in enumerate(model_camp.stations):
            comp = getattr(sta, dataset.component)
            comp.shift = result.processed_syn[ista]
            comp.sigma = 0.

    # NOTE(review): strain rate corrections are drawn for the last GNSS
    # dataset only (as before); presumably all components share the same
    # correction - confirm
    gnss_datasets = list(dataset_to_result.keys())
    strain_dataset = gnss_datasets[-1] if gnss_datasets else None

    plot_component_flags = []
    if 'east' in ds_config.components or 'north' in ds_config.components:
        plot_component_flags.append(False)
    if 'up' in ds_config.components:
        plot_component_flags.append(True)

    figs = []
    for vertical in plot_component_flags:
        m = automap.Map(
            width=figsize,
            height=figsize,
            lat=lat,
            lon=lon,
            radius=radius,
            show_topo=False,
            show_grid=True,
            show_rivers=True,
            color_wet=(216, 242, 254),
            color_dry=(238, 236, 230))

        all_stations = campaign.stations + model_camp.stations
        offset_scale = num.zeros(len(all_stations))

        # NOTE(review): this is |sum of the component shifts|, not the
        # vector norm sqrt(sum of squares) - confirm which is intended
        for ista, sta in enumerate(all_stations):
            for comp in sta.components.values():
                offset_scale[ista] += comp.shift
        offset_scale = num.sqrt(offset_scale ** 2).max()

        if len(campaign.stations) > 40:
            logger.warning('More than 40 stations disabling station labels ..')
            labels = False
        else:
            labels = True

        m.add_gnss_campaign(
            campaign,
            psxy_style={
                'G': 'black',
                'W': '0.8p,black',
            },
            offset_scale=offset_scale,
            vertical=vertical,
            labels=labels)

        m.add_gnss_campaign(
            model_camp,
            psxy_style={
                'G': 'red',
                'W': '0.8p,red',
                't': 30,
            },
            offset_scale=offset_scale,
            vertical=vertical,
            labels=False)

        for i, source in enumerate(sources):
            in_rows = source.outline(cs='lonlat')
            if mode != ffi_mode_str:
                color_str = get_gmt_colorstring_from_mpl(i)
            else:
                color_str = 'black'

            if in_rows.shape[0] > 1:  # finite source
                m.gmt.psxy(
                    in_rows=in_rows,
                    L='+p0.1p,%s' % color_str,
                    W='0.1p,black',
                    G=color_str,
                    t=70,
                    *m.jxyr)
                # emphasize the first outline segment
                m.gmt.psxy(
                    in_rows=in_rows[0:2],
                    W='1p,black',
                    *m.jxyr)
            else:  # point source
                source_scale_factor = 2.
                m.gmt.psxy(
                    in_rows=in_rows,
                    W='0.1p,black',
                    G=color_str,
                    S='c%fp' % float(source.magnitude * source_scale_factor),
                    t=70,
                    *m.jxyr)

        # plot strain rate tensor
        if strain_dataset is not None and strain_dataset.has_correction:
            for i, corr in enumerate(strain_dataset.corrections):
                if isinstance(corr, StrainRateCorrection):
                    lats, lons = corr.get_station_coordinates()
                    mid_lat, mid_lon = otd.geographic_midpoint(lats, lons)

                    corr_point = corr.get_point_rvs(bpoint)
                    srt = StrainRateTensor.from_point(corr_point)

                    in_rows = [(
                        mid_lon, mid_lat, srt.eps1, srt.eps2, srt.azimuth)]

                    color_str = get_gmt_colorstring_from_mpl(i)

                    m.gmt.psvelo(
                        in_rows=in_rows,
                        S='x%f' % offset_scale,
                        A='9p+g%s+p1p' % color_str,
                        W=color_str,
                        *m.jxyr)

        m.draw_axes()

        if po.nensemble > 1:
            # variance reductions of the component(s) shown on this map
            if vertical:
                var_reductions_ens = all_var_reductions['up']
            else:
                var_reductions_tmp = []
                if 'east' in all_var_reductions:
                    var_reductions_tmp.append(all_var_reductions['east'])
                if 'north' in all_var_reductions:
                    var_reductions_tmp.append(all_var_reductions['north'])

                var_reductions_ens = num.mean(var_reductions_tmp, axis=0)

            vmin, vmax = var_reductions_ens.min(), var_reductions_ens.max()

            inc = nice_value(vmax - vmin)
            autos = AutoScaler(inc=inc, snap='on', approx_ticks=2)

            imin, imax, sinc = autos.make_scale(
                (vmin, vmax),
                override_mode='min-max')

            nbins = 50
            args = ['-Bxa%ff%f+lVR [%s]' % (sinc, sinc, '%'),
                    '-Bya',  # dummy large value to avoid ticks
                    '-BwSne']

            # draw white background box for histogram
            m.gmt.psbasemap(
                D='n0.722/0.716+w4c/4c',
                F='+gwhite+p0.25p',
                *m.jxyr)

            # histogram of the same values the bins were derived from
            m.gmt.pshistogram(
                in_rows=pmp.utils.make_2d(var_reductions_ens),
                W=(imax - imin) / nbins,
                G='lightorange',
                F=True,
                L='0.5p,orange',
                J='X4c/4c',
                X='f13.5c',
                Y='f13.4c',
                *args)

        figs.append(m)

    return figs
def map_displacement_grid(displacements, scene):
    """
    Map per-leaf quadtree values back onto the full scene pixel grid.

    Each value fills the pixel block of its leaf. Pixels not covered by a
    leaf and pixels flagged in ``scene.displacement_mask`` are NaN.
    """
    grid = num.full_like(scene.displacement, fill_value=num.nan)

    leaves = scene.quadtree.leaves
    for value, leaf in zip(displacements, leaves):
        grid[leaf._slice_rows, leaf._slice_cols] = value

    grid[scene.displacement_mask] = num.nan
    return grid
def shaded_displacements(
        displacement, shad_data,
        cmap='RdBu', shad_lim=(.4, .98), tick_step=0.01,
        contrast=1., mask=None, data_limits=(-0.5, 0.5)):
    """
    Map color data (displacement) on shaded relief.

    Parameters
    ----------
    displacement : 2d :class:`numpy.ndarray`
        data to color code; NaN values are rendered white
    shad_data : 2d :class:`numpy.ndarray`
        data the shading is derived from (e.g. topography); same shape as
        displacement
    cmap : str or matplotlib colormap
    shad_lim : tuple of 2 floats
        range the normalized shading is mapped to (balances the gray level)
    tick_step : float
        not used; kept for backward compatibility
    contrast : float
        scales the gradient kernel
    mask : boolean array, optional
        pixels whose shading is set to NaN
    data_limits : tuple of 2 floats or None
        color limits; if None symmetric around zero using the maximum
        absolute displacement

    Returns
    -------
    rgb_map : array of shape displacement.shape + (4,)
        RGBA colors multiplied by the shading (including the alpha channel)
    """
    from scipy.ndimage import convolve as im_conv
    from matplotlib.cm import ScalarMappable

    # Light source from somewhere above - psychologically the best choice
    # from upper left
    ramp = num.array([[1, 0], [0, -1.]]) * contrast

    # convolution of two 2D arrays
    shad = im_conv(shad_data * km, ramp.T)
    shad *= -1.

    # if there are strong artifical edges in the data, shades get
    # dominated by them. Cutting off the largest and smallest 2% of
    # shades helps. nanquantile: NaNs in the data would otherwise make
    # both percentiles NaN and silently disable the clipping.
    percentile2 = num.nanquantile(shad, 0.02)
    percentile98 = num.nanquantile(shad, 0.98)
    shad[shad > percentile98] = percentile98
    shad[shad < percentile2] = percentile2

    # normalize shading to [0, 1]
    shad -= num.nanmin(shad)
    shad_max = num.nanmax(shad)
    if shad_max > 0.:
        # constant shading would otherwise give 0 / 0 -> all NaN
        shad /= shad_max

    if mask is not None:
        shad[mask] = num.nan

    # reduce range to balance gray color
    shad *= shad_lim[1] - shad_lim[0]
    shad += shad_lim[0]

    # color limits, symmetric around zero if not given
    if data_limits is None:
        data_max = num.nanmax(num.abs(displacement))
        data_limits = (-data_max, data_max)

    displ_min, displ_max = data_limits

    # Combine color and shading
    color_map = ScalarMappable(cmap=cmap)
    color_map.set_clim(displ_min, displ_max)

    rgb_map = color_map.to_rgba(displacement)
    rgb_map[num.isnan(displacement)] = 1.
    rgb_map *= shad[:, :, num.newaxis]

    return rgb_map
def get_latlon_ratio(lat, lon):
    """
    Ratio of the length of one degree of latitude to the length of one
    degree of longitude at the given location.
    """
    meters_per_deg_lat = otd.distance_accurate50m(lat, lon, lat - 1., lon)
    meters_per_deg_lon = otd.distance_accurate50m(lat, lon, lat, lon - 1.)
    return meters_per_deg_lat / meters_per_deg_lon
def plot_inset_hist(
        axes, data, best_data, bbox_to_anchor, linewidth=0.5, labelsize=5,
        cmap=None, cbounds=None, color='orange', alpha=0.4):
    """
    Draw a small histogram of ``data`` as inset into ``axes``.

    Parameters
    ----------
    axes : matplotlib axes
        parent axes
    data : array
        values to histogram (passed to histplot_op)
    best_data : float or None
        value marked with a red vertical line; None draws no line
    bbox_to_anchor : tuple
        inset position in axes coordinates of ``axes``
    linewidth, labelsize : float
        styling of lines and tick labels
    cmap, cbounds, color, alpha
        passed to histplot_op

    Returns
    -------
    in_ax : the inset axes
    """
    in_ax = inset_axes(
        axes, width="100%", height="100%",
        bbox_to_anchor=bbox_to_anchor,
        bbox_transform=axes.transAxes, loc=2, borderpad=0)
    histplot_op(
        in_ax, data,
        alpha=alpha, color=color, cmap=cmap, cbounds=cbounds, tstd=0.)

    format_axes(in_ax)
    format_axes(
        in_ax, remove=['bottom'], visible=True,
        linewidth=linewidth)

    # compare against None: a best value of exactly 0 must still be drawn
    if best_data is not None:
        in_ax.axvline(
            x=best_data,
            color='red', lw=linewidth)

    in_ax.tick_params(
        axis='both', direction='in', labelsize=labelsize,
        width=linewidth)
    in_ax.yaxis.set_visible(False)
    xticker = tick.MaxNLocator(nbins=2)
    in_ax.xaxis.set_major_locator(xticker)
    return in_ax
[docs]def scene_fits(problem, stage, plot_options):
"""
Plot geodetic data, synthetics and residuals.
"""
from pyrocko.dataset import gshhg
from kite.scene import Scene, UserIOWarning
from beat.colormap import roma_colormap
import gc
try:
homepath = problem.config.geodetic_config.types['SAR'].datadir
except KeyError:
raise ValueError('SAR data not in geodetic types!')
datatype = 'geodetic'
mode = problem.config.problem_config.mode
problem.init_hierarchicals()
fontsize = 10
fontsize_title = 12
ndmax = 3
nxmax = 3
# cmap = plt.cm.jet
# cmap = roma_colormap(256)
cmap = plt.cm.RdYlBu_r
po = plot_options
composite = problem.composites[datatype]
event = composite.event
try:
sources = composite.sources
ref_sources = None
except AttributeError:
logger.info('FFI scene fit, using reference source ...')
ref_sources = composite.config.gf_config.reference_sources
set_anchor(ref_sources, anchor='top')
fault = composite.load_fault_geometry()
sources = fault.get_all_subfaults(
datatype=datatype, component=composite.slip_varnames[0])
set_anchor(sources, anchor='top')
if po.reference:
if mode != ffi_mode_str:
composite.point2sources(po.reference)
ref_sources = copy.deepcopy(composite.sources)
bpoint = po.reference
else:
bpoint = get_result_point(stage.mtrace, po.post_llk)
bresults_tmp = composite.assemble_results(bpoint)
bvar_reductions = composite.get_variance_reductions(
bpoint, weights=composite.weights, results=bresults_tmp)
dataset_to_result = OrderedDict()
for dataset, bresult in zip(composite.datasets, bresults_tmp):
if dataset.typ == 'SAR':
dataset_to_result[dataset] = bresult
results = dataset_to_result.values()
dataset_index = dict(
(data, i) for (i, data) in enumerate(dataset_to_result.keys()))
nrmax = len(dataset_to_result.keys())
fullfig, restfig = utility.mod_i(nrmax, ndmax)
factors = num.ones(fullfig).tolist()
if restfig:
factors.append(float(restfig) / ndmax)
topo_plot_thresh = 300
if plot_options.nensemble > topo_plot_thresh:
logger.info(
'Plotting shaded relief as nensemble > %i.' % topo_plot_thresh)
show_topo = True
else:
logger.info(
'Not plotting shaded relief for nensemble < %i.' % (
topo_plot_thresh + 1))
show_topo = False
if po.nensemble > 1:
from tqdm import tqdm
logger.info(
'Collecting ensemble of %i '
'synthetic displacements ...' % po.nensemble)
nchains = len(stage.mtrace)
csteps = float(nchains) / po.nensemble
idxs = num.floor(num.arange(0, nchains, csteps)).astype('int32')
ens_results = []
points = []
ens_var_reductions = []
for idx in tqdm(idxs):
point = stage.mtrace.point(idx=idx)
points.append(point)
e_results = composite.assemble_results(point)
ens_results.append(e_results)
ens_var_reductions.append(
composite.get_variance_reductions(
point, weights=composite.weights, results=e_results))
all_var_reductions = {}
for dataset in dataset_to_result.keys():
target_var_reds = []
target_var_reds.append(bvar_reductions[dataset.name])
for var_reds in ens_var_reductions:
target_var_reds.append(var_reds[dataset.name])
all_var_reductions[dataset.name] = num.array(target_var_reds) * 100.
figures = []
axes = []
for f in factors:
figsize = list(mpl_papersize('a4', 'portrait'))
figsize[1] *= f
fig, ax = plt.subplots(
nrows=int(round(ndmax * f)), ncols=nxmax, figsize=figsize)
fig.tight_layout()
fig.subplots_adjust(
left=0.08,
right=1.0 - 0.03,
bottom=0.06,
top=1.0 - 0.06,
wspace=0.,
hspace=0.1)
figures.append(fig)
ax_a = num.atleast_2d(ax)
axes.append(ax_a)
nfigs = len(figures)
def axis_config(axes, source, scene, po):
latlon_ratio = get_latlon_ratio(source.lat, source.lon)
for i, ax in enumerate(axes):
if po.plot_projection == 'latlon':
ystr = 'Latitude [deg]'
xstr = 'Longitude [deg]'
if scene.frame.isDegree():
scale_x = {'scale': 1.}
scale_y = {'scale': 1.}
ax.set_aspect(latlon_ratio)
else:
scale_x = {'scale': otd.m2d}
scale_y = {'scale': otd.m2d}
ax.set_aspect('equal')
scale_x['offset'] = source.lon
scale_y['offset'] = source.lat
elif po.plot_projection == 'local':
ystr = 'Distance [km]'
xstr = 'Distance [km]'
if scene.frame.isDegree():
scale_x = {'scale': otd.d2m / km / latlon_ratio}
scale_y = {'scale': otd.d2m / km}
ax.set_aspect(latlon_ratio)
else:
scale_x = {'scale': 1. / km}
scale_y = {'scale': 1. / km}
ax.set_aspect('equal')
else:
raise TypeError(
'Plot projection %s not available' % po.plot_projection)
ax.xaxis.set_major_locator(tick.MaxNLocator(nbins=3))
ax.yaxis.set_major_locator(tick.MaxNLocator(nbins=3))
if i == 0:
ax.set_ylabel(ystr, fontsize=fontsize)
ax.set_xlabel(xstr, fontsize=fontsize)
ax.set_yticklabels(ax.get_yticklabels(), rotation=90)
ax.scale_x = scale_x
ax.scale_y = scale_y
scale_axes(ax.get_xaxis(), **scale_x)
scale_axes(ax.get_yaxis(), **scale_y)
if i > 0:
ax.set_yticklabels([])
ax.set_xticklabels([])
def draw_coastlines(ax, xlim, ylim, event, scene, po):
"""
xlim and ylim in Lon/Lat[deg]
"""
logger.debug('Drawing coastlines ...')
coasts = gshhg.GSHHG.full()
if po.plot_projection == 'latlon':
west, east = xlim
south, north = ylim
elif po.plot_projection == 'local':
lats, lons = otd.ne_to_latlon(
event.lat, event.lon,
north_m=num.array(ylim) * km, east_m=num.array(xlim) * km)
south, north = lats
west, east = lons
polygons = coasts.get_polygons_within(
west=west, east=east, south=south, north=north)
for p in polygons:
if (p.is_land() or p.is_antarctic_grounding_line() or
p.is_island_in_lake()):
if scene.frame.isMeter():
ys, xs = otd.latlon_to_ne_numpy(
event.lat, event.lon, p.lats, p.lons)
elif scene.frame.isDegree():
xs = p.lons - event.lon
ys = p.lats - event.lat
ax.plot(xs, ys, '-k', linewidth=0.5)
def add_arrow(ax, scene):
phi = num.nanmean(scene.phi)
theta = num.nanmean(scene.theta)
if theta == 0.: # MAI / az offsets
phi -= num.pi
los_dx = num.cos(phi + num.pi) * .0625
los_dy = num.sin(phi + num.pi) * .0625
az_dx = num.cos(phi - num.pi / 2) * .125
az_dy = num.sin(phi - num.pi / 2) * .125
anchor_x = .9 if los_dx < 0 else .1
anchor_y = .85 if los_dx < 0 else .975
if theta > 0.: # MAI / az offsets
az_arrow = FancyArrow(
x=anchor_x - az_dx, y=anchor_y - az_dy,
dx=az_dx, dy=az_dy,
head_width=.025,
alpha=.5, fc='k',
head_starts_at_zero=False,
length_includes_head=True,
transform=ax.transAxes)
ax.add_artist(az_arrow)
los_arrow = FancyArrow(
x=anchor_x - az_dx / 2, y=anchor_y - az_dy / 2,
dx=los_dx, dy=los_dy,
head_width=.02,
alpha=.5, fc='k',
head_starts_at_zero=False,
length_includes_head=True,
transform=ax.transAxes)
ax.add_artist(los_arrow)
def draw_leaves(ax, scene, offset_e=0, offset_n=0):
rects = scene.quadtree.getMPLRectangles()
for r in rects:
r.set_edgecolor((.4, .4, .4))
r.set_linewidth(.5)
r.set_facecolor('none')
r.set_x(r.get_x() - offset_e)
r.set_y(r.get_y() - offset_n)
map(ax.add_artist, rects)
ax.scatter(scene.quadtree.leaf_coordinates[:, 0] - offset_e,
scene.quadtree.leaf_coordinates[:, 1] - offset_n,
s=.25, c='black', alpha=.1)
def draw_sources(ax, sources, scene, po, event, **kwargs):
bgcolor = kwargs.pop('color', None)
for i, source in enumerate(sources):
if scene.frame.isMeter():
fn, fe = source.outline(cs='xy').T
elif scene.frame.isDegree():
fn, fe = source.outline(cs='latlon').T
fn -= event.lat
fe -= event.lon
if not bgcolor:
color = mpl_graph_color(i)
else:
color = bgcolor
if fn.size > 1:
alpha = 0.4
ax.plot(
fe, fn, '-',
linewidth=0.5, color=color, alpha=alpha, **kwargs)
ax.fill(
fe, fn,
edgecolor=color,
facecolor=light(color, .5), alpha=alpha)
ax.plot(
fe[0:2], fn[0:2], '-k', alpha=0.7,
linewidth=1.0)
else:
ax.plot(
fe, fn, marker='*',
markersize=10, color=color, **kwargs)
def cbtick(x):
rx = math.floor(x * 1000.) / 1000.
return [-rx, rx]
colims = [num.max([
num.max(num.abs(r.processed_obs)),
num.max(num.abs(r.processed_syn))]) for r in results]
dcolims = [num.max(num.abs(r.processed_res)) for r in results]
import string
for idata, (dataset, result) in enumerate(dataset_to_result.items()):
subplot_letter = string.ascii_lowercase[idata]
try:
scene_path = os.path.join(homepath, dataset.name)
logger.info(
'Loading full resolution kite scene: %s' % scene_path)
scene = Scene.load(scene_path)
except UserIOWarning:
logger.warning(
'Full resolution data could not be loaded! Skipping ...')
continue
if scene.frame.isMeter():
offset_n, offset_e = map(float, otd.latlon_to_ne_numpy(
scene.frame.llLat, scene.frame.llLon,
event.lat, event.lon))
elif scene.frame.isDegree():
offset_n = event.lat - scene.frame.llLat
offset_e = event.lon - scene.frame.llLon
im_extent = (scene.frame.E.min() - offset_e,
scene.frame.E.max() - offset_e,
scene.frame.N.min() - offset_n,
scene.frame.N.max() - offset_n)
urE, urN, llE, llN = (0., 0., 0., 0.)
turE, turN, tllE, tllN = zip(
*[(l.gridE.max() - offset_e,
l.gridN.max() - offset_n,
l.gridE.min() - offset_e,
l.gridN.min() - offset_n)
for l in scene.quadtree.leaves])
turE, turN = map(max, (turE, turN))
tllE, tllN = map(min, (tllE, tllN))
urE, urN = map(max, ((turE, urE), (urN, turN)))
llE, llN = map(min, ((tllE, llE), (llN, tllN)))
lat, lon = otd.ne_to_latlon(
sources[0].lat, sources[0].lon,
num.array([llN, urN]), num.array([llE, urE]))
# result = dataset_to_result[dataset]
tidx = dataset_index[dataset]
figidx, rowidx = utility.mod_i(tidx, ndmax)
axs = axes[figidx][rowidx, :]
imgs = []
for ax, data_str in zip(axs, ['obs', 'syn', 'res']):
logger.info('Plotting %s' % data_str)
datavec = getattr(result, 'processed_%s' % data_str)
if data_str == 'res' and po.plot_projection == 'local':
vmin = -dcolims[tidx]
vmax = dcolims[tidx]
else:
vmin = -colims[tidx]
vmax = colims[tidx]
data = map_displacement_grid(datavec, scene)
if show_topo:
elevation = scene.get_elevation()
elevation_mask = num.where(elevation == 0., True, False)
data = shaded_displacements(
data, elevation, cmap,
shad_lim=(0.4, .99), contrast=1., mask=elevation_mask,
data_limits=(vmin, vmax))
imgs.append(
ax.imshow(
data, extent=im_extent, cmap=cmap,
vmin=vmin, vmax=vmax,
origin='lower', interpolation='nearest'))
ax.set_xlim(llE, urE)
ax.set_ylim(llN, urN)
draw_leaves(ax, scene, offset_e, offset_n)
draw_coastlines(
ax, lon, lat, event, scene, po)
if po.nensemble > 1:
in_ax = plot_inset_hist(
axs[2],
data=pmp.utils.make_2d(all_var_reductions[dataset.name]),
best_data=bvar_reductions[dataset.name] * 100.,
linewidth=1.,
bbox_to_anchor=(0.75, .775, .25, .225),
labelsize=6)
format_axes(
in_ax, remove=['left', 'bottom'], visible=True,
linewidth=0.75)
in_ax.set_xlabel('VR [%]', fontsize=fontsize - 3)
fontdict = {
'fontsize': fontsize,
'fontweight': 'bold',
'verticalalignment': 'top'}
transform = axes[figidx][rowidx, 0].transAxes
if dataset.name[-5::] == 'dscxn':
title = 'descending'
elif dataset.name[-5::] == 'ascxn':
title = 'ascending'
else:
title = dataset.name
axes[figidx][rowidx, 0].text(
.025, 1.025, '({}) {}'.format(subplot_letter, title),
fontsize=fontsize_title, alpha=1.,
va='bottom', transform=transform)
for i, quantity in enumerate(['data', 'model', 'residual']):
transform = axes[figidx][rowidx, i].transAxes
axes[figidx][rowidx, i].text(
0.5, 0.95, quantity, fontdict, transform=transform,
horizontalalignment='center')
draw_sources(
axes[figidx][rowidx, 1], sources, scene, po, event=event)
if ref_sources:
ref_color = scolor('aluminium4')
logger.info('Plotting reference sources')
draw_sources(
axes[figidx][rowidx, 1],
ref_sources, scene, po, event=event, color=ref_color)
f = factors[figidx]
if f > 2. / 3:
cbb = (0.68 - (0.3075 * rowidx))
elif f > 1. / 2:
cbb = (0.53 - (0.47 * rowidx))
elif f > 1. / 4:
cbb = (0.06)
cbl = 0.46
cbw = 0.15
cbh = 0.01
cbaxes = figures[figidx].add_axes([cbl, cbb, cbw, cbh])
cblabel = 'LOS displacement [m]'
cbs = plt.colorbar(
imgs[1],
ax=axes[figidx][rowidx, 0],
ticks=cbtick(colims[tidx]),
cax=cbaxes,
orientation='horizontal',
cmap=cmap)
cbs.set_label(cblabel, fontsize=fontsize)
if po.plot_projection == 'local':
dcbaxes = figures[figidx].add_axes([cbl + 0.3, cbb, cbw, cbh])
cbr = plt.colorbar(
imgs[2],
ax=axes[figidx][rowidx, 2],
ticks=cbtick(dcolims[tidx]),
cax=dcbaxes,
orientation='horizontal',
cmap=cmap)
cbr.set_label(cblabel, fontsize=fontsize)
axis_config(axes[figidx][rowidx, :], event, scene, po)
add_arrow(axes[figidx][rowidx, 0], scene)
del scene
gc.collect()
return figures
def draw_scene_fits(problem, plot_options):
    """
    Plot data, model and residual SAR scenes and display or save them.

    Parameters
    ----------
    problem : problem object
        must contain a 'geodetic' composite and 'SAR' in
        ``problem.config.geodetic_config.types``
    plot_options : :class:`PlotOptions`

    Raises
    ------
    TypeError
        if no geodetic composite or no SAR data is defined in the problem
    """
    if 'geodetic' not in list(problem.composites.keys()):
        raise TypeError('No geodetic composite defined in the problem!')

    if 'SAR' not in problem.config.geodetic_config.types:
        raise TypeError('There is no SAR data in the problem setup!')

    logger.info('Drawing SAR misfits ...')

    po = plot_options

    stage = Stage(homepath=problem.outfolder,
                  backend=problem.config.sampler_config.backend)

    if not po.reference:
        stage.load_results(
            varnames=problem.varnames,
            model=problem.model, stage_number=po.load_stage,
            load='trace', chains=[-1])
        llk_str = po.post_llk
    else:
        llk_str = 'ref'

    mode = problem.config.problem_config.mode

    # outpath is only a file prefix; the extension / figure index is
    # appended when saving below
    outpath = os.path.join(
        problem.config.project_dir,
        mode, po.figure_dir, 'scenes_%s_%s_%s_%i' % (
            stage.number, llk_str, po.plot_projection, po.nensemble))

    # check the first file that saving would actually write, otherwise the
    # "plots exist" branch is never reached
    if po.outformat == 'pdf':
        first_outfile = outpath + '.pdf'
    else:
        first_outfile = '%s_%i.%s' % (outpath, 0, po.outformat)

    if po.outformat == 'display' or po.force or \
            not os.path.exists(first_outfile):
        figs = scene_fits(problem, stage, po)
    else:
        logger.info('scene plots exist. Use force=True for replotting!')
        return

    if po.outformat == 'display':
        plt.show()
    else:
        logger.info('saving figures to %s' % outpath)
        if po.outformat == 'pdf':
            with PdfPages(outpath + '.pdf') as opdf:
                for fig in figs:
                    opdf.savefig(fig)
        else:
            for i, fig in enumerate(figs):
                fig.savefig(
                    '%s_%i.%s' % (outpath, i, po.outformat), dpi=po.dpi)
def draw_gnss_fits(problem, plot_options):
    """
    Plot GNSS data and model fits (horizontal and vertical components) and
    display or save them.

    Parameters
    ----------
    problem : problem object
        must contain a 'geodetic' composite and 'GNSS' in
        ``problem.config.geodetic_config.types``
    plot_options : :class:`PlotOptions`

    Raises
    ------
    TypeError
        if no geodetic composite or no GNSS data is defined in the problem
    """
    if 'geodetic' not in list(problem.composites.keys()):
        raise TypeError('No geodetic composite defined in the problem!')

    if 'GNSS' not in problem.config.geodetic_config.types:
        raise TypeError('There is no GNSS data in the problem setup!')

    logger.info('Drawing GNSS misfits ...')

    po = plot_options

    stage = Stage(homepath=problem.outfolder,
                  backend=problem.config.sampler_config.backend)

    if not po.reference:
        stage.load_results(
            varnames=problem.varnames,
            model=problem.model, stage_number=po.load_stage,
            load='trace', chains=[-1])
        llk_str = po.post_llk
    else:
        llk_str = 'ref'

    mode = problem.config.problem_config.mode

    # outpath is only a file prefix; the component name and extension are
    # appended when saving below
    outpath = os.path.join(
        problem.config.project_dir,
        mode, po.figure_dir, 'gnss_%s_%s_%i_%s' % (
            stage.number, llk_str, po.nensemble, po.plot_projection))

    # check the first file that saving would actually write, otherwise the
    # "plots exist" branch is never reached
    first_outfile = outpath + '_%s.%s' % ('horizontal', po.outformat)

    if po.outformat == 'display' or po.force or \
            not os.path.exists(first_outfile):
        figs = gnss_fits(problem, stage, po)
    else:
        logger.info('gnss plots exist. Use force=True for replotting!')
        return

    if po.outformat == 'display':
        plt.show()
    else:
        logger.info('saving figures to %s' % outpath)
        for component, fig in zip(('horizontal', 'vertical'), figs):
            fig.save(outpath + '_%s.%s' % (
                component, po.outformat), resolution=po.dpi)
def extract_time_shifts(point, wmap):
    """
    Return the station time-shifts of a waveform map from a sampling point.

    Parameters
    ----------
    point : dict
        sampling point, mapping variable names to values
    wmap : waveform map
        provides ``time_shifts_id`` (key into ``point``) and
        ``station_correction_idxs`` (indices selecting the stations)

    Returns
    -------
    time-shift values of ``point[wmap.time_shifts_id]`` indexed by
    ``wmap.station_correction_idxs``

    Raises
    ------
    ValueError
        if ``point`` does not contain ``wmap.time_shifts_id``
    """
    try:
        all_time_shifts = point[wmap.time_shifts_id]
    except KeyError as err:
        raise ValueError(
            'Sampling results do not contain time-shifts for wmap'
            ' %s!' % wmap.time_shifts_id) from err

    return all_time_shifts[wmap.station_correction_idxs]
def seismic_fits(problem, stage, plot_options):
    """
    Modified from grond. Plot synthetic and data waveforms and the misfit for
    the selected posterior model.

    Parameters
    ----------
    problem : problem object with a 'seismic' composite
    stage : :class:`beat.models.Stage`
        stage whose ``mtrace`` is used to pick the plotted point(s)
    plot_options : :class:`PlotOptions`

    Returns
    -------
    list of :class:`matplotlib.figure.Figure`
        figures for every channel group (several per group if the
        stations do not fit on one page)
    """

    def plot_trace(axes, tr, **kwargs):
        return axes.plot(tr.get_xdata(), tr.get_ydata(), **kwargs)

    def plot_taper(axes, t, taper, mode='geometry', **kwargs):
        y = num.ones(t.size) * 0.9
        if mode == 'geometry':
            taper(y, t[0], t[1] - t[0])
        y2 = num.concatenate((y, -y[::-1]))
        t2 = num.concatenate((t, t[::-1]))
        axes.fill(t2, y2, **kwargs)

    def plot_dtrace(axes, tr, space, mi, ma, **kwargs):
        t = tr.get_xdata()
        y = tr.get_ydata()
        y2 = (num.concatenate((y, num.zeros(y.size))) - mi) / \
            (ma - mi) * space - (1.0 + space)
        t2 = num.concatenate((t, t[::-1]))
        axes.fill(
            t2, y2,
            clip_on=False,
            **kwargs)

    def plot_inset_hist(
            axes, data, best_data, bbox_to_anchor,
            cmap=None, cbounds=None, color='orange', alpha=0.4):

        in_ax = inset_axes(
            axes, width="100%", height="100%",
            bbox_to_anchor=bbox_to_anchor,
            bbox_transform=axes.transAxes, loc=2, borderpad=0)
        histplot_op(
            in_ax, data,
            alpha=alpha, color=color, cmap=cmap, cbounds=cbounds, tstd=0.)

        format_axes(in_ax)
        linewidth = 0.5
        format_axes(
            in_ax, remove=['bottom'], visible=True,
            linewidth=linewidth)

        # explicit None check: a best value of exactly 0 must still be drawn
        if best_data is not None:
            in_ax.axvline(
                x=best_data,
                color='red', lw=linewidth)

        in_ax.tick_params(
            axis='both', direction='in', labelsize=5, width=linewidth)
        in_ax.tick_params(top=False)
        in_ax.yaxis.set_visible(False)
        xticker = tick.MaxNLocator(nbins=2)
        in_ax.xaxis.set_major_locator(xticker)
        return in_ax

    def skey(tr):
        return tr.channel

    composite = problem.composites['seismic']

    fontsize = 8
    fontsize_title = 10

    target_index = {
        target: i for i, target in enumerate(composite.targets)}

    po = plot_options

    if not po.reference:
        best_point = get_result_point(stage.mtrace, po.post_llk)
    else:
        best_point = po.reference

    try:
        composite.point2sources(best_point)
        source = composite.sources[0]
        chop_bounds = ['a', 'd']
    except AttributeError:
        logger.info('FFI waveform fit, using reference source ...')
        source = composite.config.gf_config.reference_sources[0]
        source.time = composite.event.time
        chop_bounds = ['b', 'c']

    if best_point:  # for source individual contributions
        bresults = composite.assemble_results(
            best_point, outmode='tapered_data', chop_bounds=chop_bounds)
        synth_plot_flag = True
    else:
        # get dummy results for data
        logger.warning(
            'Got "None" post_llk, still loading MAP for VR calculation')
        best_point = get_result_point(stage.mtrace, 'max')
        bresults = composite.assemble_results(
            best_point, chop_bounds=chop_bounds)
        synth_plot_flag = False

    composite.analyse_noise(best_point, chop_bounds=chop_bounds)
    composite.update_weights(best_point, chop_bounds=chop_bounds)

    if plot_options.nensemble > 1:
        from tqdm import tqdm
        logger.info(
            'Collecting ensemble of %i synthetic waveforms ...' % po.nensemble)
        nchains = len(stage.mtrace)
        csteps = float(nchains) / po.nensemble
        idxs = num.floor(num.arange(0, nchains, csteps)).astype('int32')
        ens_results = []
        points = []
        ens_var_reductions = []
        for idx in tqdm(idxs):
            point = stage.mtrace.point(idx=idx)
            points.append(point)
            results = composite.assemble_results(
                point, chop_bounds=chop_bounds)
            ens_results.append(results)
            ens_var_reductions.append(
                composite.get_variance_reductions(
                    point, weights=composite.weights,
                    results=results, chop_bounds=chop_bounds))

    bvar_reductions = composite.get_variance_reductions(
        best_point, weights=composite.weights,
        results=bresults, chop_bounds=chop_bounds)

    # collecting results for targets
    logger.info('Mapping results to targets ...')
    target_to_results = {}
    all_syn_trs_target = {}
    all_var_reductions = {}
    dtraces = []
    for target in composite.targets:
        target_results = []
        target_synths = []
        target_var_reductions = []

        i = target_index[target]

        nslc_id_str = utility.list2string(target.codes)

        target_results.append(bresults[i])
        target_synths.append(bresults[i].processed_syn)
        target_var_reductions.append(
            bvar_reductions[nslc_id_str])

        dtraces.append(copy.deepcopy(bresults[i].processed_res))
        if plot_options.nensemble > 1:
            for results, var_reductions in zip(
                    ens_results, ens_var_reductions):
                # put all results per target here not only single
                target_results.append(results[i])
                target_synths.append(results[i].processed_syn)
                target_var_reductions.append(
                    var_reductions[nslc_id_str])

        target_to_results[target] = target_results
        all_syn_trs_target[target] = target_synths
        all_var_reductions[target] = num.array(target_var_reductions) * 100.

    # collecting time-shifts:
    station_corr = composite.config.station_corrections
    if station_corr:
        tshifts = problem.config.problem_config.hierarchicals['time_shift']
        time_shift_bounds = [tshifts.lower, tshifts.upper]

        logger.info('Collecting time-shifts ...')
        if plot_options.nensemble > 1:
            ens_time_shifts = []
            for point in points:
                comp_time_shifts = []
                for wmap in composite.wavemaps:
                    comp_time_shifts.append(
                        extract_time_shifts(point, wmap))

                ens_time_shifts.append(
                    num.hstack(comp_time_shifts))

        btime_shifts = num.hstack(
            [extract_time_shifts(best_point, wmap)
             for wmap in composite.wavemaps])

        logger.info('Mapping time-shifts to targets ...')

        all_time_shifts = {}
        for target in composite.targets:
            target_time_shifts = []
            i = target_index[target]
            target_time_shifts.append(btime_shifts[i])

            if plot_options.nensemble > 1:
                for time_shifts in ens_time_shifts:
                    target_time_shifts.append(time_shifts[i])

            all_time_shifts[target] = num.array(target_time_shifts)

    # normalise residual traces per channel
    dminmaxs = trace.minmax(dtraces, skey)

    for tr in dtraces:
        if tr:
            dmin, dmax = dminmaxs[skey(tr)]
            norm = max(abs(dmin), abs(dmax))
            if norm > 0.:  # all-zero residuals would give NaN otherwise
                tr.ydata /= norm

    cg_to_targets = utility.gather(
        composite.targets,
        lambda t: t.codes[3],
        filter=lambda t: t in target_to_results)

    cgs = cg_to_targets.keys()

    figs = []
    logger.info('Plotting waveforms ...')
    for cg in cgs:
        targets = cg_to_targets[cg]

        nframes = len(targets)

        nx = int(math.ceil(math.sqrt(nframes)))
        ny = (nframes - 1) // nx + 1

        logger.debug('nx %i, ny %i' % (nx, ny))

        nxmax = 4
        nymax = 4

        nxx = (nx - 1) // nxmax + 1
        nyy = (ny - 1) // nymax + 1

        # normalised frame positions; true division keeps them distinct
        # (floor division collapsed neighbouring frames onto one position)
        xs = num.arange(nx) / ((max(2, nx) - 1.0) / 2.)
        ys = num.arange(ny) / ((max(2, ny) - 1.0) / 2.)

        xs -= num.mean(xs)
        ys -= num.mean(ys)

        fxs = num.tile(xs, ny)
        fys = num.repeat(ys, nx)

        data = []

        for target in targets:
            azi = source.azibazi_to(target)[0]
            dist = source.distance_to(target)
            x = dist * num.sin(num.deg2rad(azi))
            y = dist * num.cos(num.deg2rad(azi))
            data.append((x, y, dist))

        gxs, gys, dists = num.array(data, dtype=float).T

        iorder = num.argsort(dists)

        gxs = gxs[iorder]
        gys = gys[iorder]
        targets_sorted = [targets[ii] for ii in iorder]

        gxs -= num.mean(gxs)
        gys -= num.mean(gys)

        gmax = max(num.max(num.abs(gys)), num.max(num.abs(gxs)))
        if gmax == 0.:
            gmax = 1.

        gxs /= gmax
        gys /= gmax

        dists = num.sqrt(
            (fxs[num.newaxis, :] - gxs[:, num.newaxis]) ** 2 +
            (fys[num.newaxis, :] - gys[:, num.newaxis]) ** 2)

        distmax = num.max(dists)

        # greedily assign each target (nearest first) to the closest free
        # frame of the nx x ny grid
        availmask = num.ones(dists.shape[1], dtype=bool)
        frame_to_target = {}
        for itarget, target in enumerate(targets_sorted):
            iframe = num.argmin(
                num.where(availmask, dists[itarget], distmax + 1.))
            availmask[iframe] = False
            iy, ix = num.unravel_index(iframe, (ny, nx))
            frame_to_target[iy, ix] = target

        figures = {}
        for iy in range(ny):
            for ix in range(nx):
                if (iy, ix) not in frame_to_target:
                    continue

                ixx = ix // nxmax
                iyy = iy // nymax
                if (iyy, ixx) not in figures:
                    figures[iyy, ixx] = plt.figure(
                        figsize=mpl_papersize('a4', 'landscape'))

                    figures[iyy, ixx].subplots_adjust(
                        left=0.03,
                        right=1.0 - 0.03,
                        bottom=0.03,
                        top=1.0 - 0.06,
                        wspace=0.2,
                        hspace=0.2)

                    figs.append(figures[iyy, ixx])

                logger.debug('iyy %i, ixx %i' % (iyy, ixx))
                logger.debug('iy %i, ix %i' % (iy, ix))
                fig = figures[iyy, ixx]

                target = frame_to_target[iy, ix]

                # get min max of all traces
                key = target.codes[3]
                amin, amax = trace.minmax(
                    all_syn_trs_target[target],
                    key=skey)[key]
                # need target specific minmax
                absmax = max(abs(amin), abs(amax))

                ny_this = nymax  # min(ny, nymax)
                nx_this = nxmax  # min(nx, nxmax)
                i_this = (iy % ny_this) * nx_this + (ix % nx_this) + 1
                logger.debug('i_this %i' % i_this)
                logger.debug('Station {}'.format(
                    utility.list2string(target.codes)))
                axes2 = fig.add_subplot(ny_this, nx_this, i_this)

                space = 0.5
                space_factor = 1.0 + space
                axes2.set_axis_off()
                axes2.set_ylim(-1.05 * space_factor, 1.05)

                axes = axes2.twinx()
                axes.set_axis_off()

                ymin, ymax = - absmax * 1.33 * space_factor, absmax * 1.33
                try:
                    axes.set_ylim(ymin, ymax)
                except ValueError:
                    logger.debug(
                        'These traces contain NaN or Inf open in snuffler?')
                    input('Press enter! Otherwise Ctrl + C')
                    from pyrocko.trace import snuffle
                    snuffle(all_syn_trs_target[target])

                itarget = target_index[target]
                result = bresults[itarget]

                traces = all_syn_trs_target[target]
                dtrace = dtraces[itarget]

                if po.nensemble > 1:
                    fuzzy_waveforms(
                        axes, traces, linewidth=7, zorder=0,
                        grid_size=(500, 500), alpha=1.0)

                tap_color_annot = (0.35, 0.35, 0.25)
                tap_color_edge = (0.85, 0.85, 0.80)
                plot_taper(
                    axes2, result.processed_obs.get_xdata(), result.taper,
                    mode=composite._mode, fc='None', ec=tap_color_edge,
                    zorder=4, alpha=0.6)

                time_shift_color = scolor('aluminium3')
                obs_color = scolor('aluminium5')
                syn_color = scolor('scarletred2')
                misfit_color = scolor('scarletred2')

                if synth_plot_flag:
                    # only draw if highlighted point exists
                    plot_dtrace(
                        axes2, dtrace, space, 0., 1.,
                        fc=light(misfit_color, 0.3),
                        ec=misfit_color, zorder=4)

                    if po.plot_projection == 'individual':
                        for i, tr in enumerate(result.source_contributions):
                            plot_trace(
                                axes, tr,
                                color=mpl_graph_color(i), lw=0.5, zorder=5)
                    else:
                        plot_trace(
                            axes, result.processed_syn,
                            color=syn_color, lw=0.5, zorder=5)

                plot_trace(
                    axes, result.processed_obs,
                    color=obs_color, lw=0.5, zorder=5)

                xdata = result.processed_obs.get_xdata()
                axes.set_xlim(xdata[0], xdata[-1])

                tmarks = [
                    result.processed_obs.tmin,
                    result.processed_obs.tmax]

                tmark_fontsize = fontsize - 1

                # plot variance reductions
                if po.nensemble > 1:
                    nslc_id_str = utility.list2string(target.codes)
                    logger.debug(
                        'Plotting variance reductions for %s' % nslc_id_str)
                    if synth_plot_flag:
                        best_data = bvar_reductions[nslc_id_str] * 100.
                    else:  # for None post_llk
                        best_data = None

                    in_ax = plot_inset_hist(
                        axes,
                        data=pmp.utils.make_2d(all_var_reductions[target]),
                        best_data=best_data,
                        bbox_to_anchor=(0.9, .75, .2, .2))
                    in_ax.set_title('VR [%]', fontsize=5)

                if station_corr:
                    sidebar_ybounds = [-0.9, -1.3]
                    ytmarks = [-1.3, -1.3]
                    hor_alignment = 'center'

                    if synth_plot_flag:
                        best_data = btime_shifts[itarget]
                    else:  # for None post_llk
                        best_data = None

                    if po.nensemble > 1:
                        in_ax = plot_inset_hist(
                            axes,
                            data=pmp.utils.make_2d(all_time_shifts[target]),
                            best_data=best_data,
                            bbox_to_anchor=(-0.0985, .26, .2, .2),
                            color=time_shift_color,
                            alpha=0.7)
                        in_ax.set_xlim(*time_shift_bounds)
                else:
                    sidebar_ybounds = [-1.2, -1.2]
                    ytmarks = [-1.2, -1.2]
                    hor_alignment = 'left'

                for tmark, ybound in zip(tmarks, sidebar_ybounds):
                    axes2.plot(
                        [tmark, tmark], [ybound, 0.1], color=tap_color_annot)

                for xtmark, ytmark, text, ha, va in [
                        (tmarks[0], ytmarks[0],
                         r'$\,$ ' + str_duration(tmarks[0] - source.time),
                         hor_alignment,
                         'bottom'),
                        (tmarks[1], ytmarks[1],
                         r'$\Delta$ ' + str_duration(tmarks[1] - tmarks[0]),
                         'right',
                         'bottom')]:

                    axes2.annotate(
                        text,
                        xy=(xtmark, ytmark),
                        xycoords='data',
                        xytext=(
                            fontsize * 0.4 * [-1, 1][ha == 'left'],
                            fontsize * 0.2),
                        textcoords='offset points',
                        ha=ha,
                        va=va,
                        color=tap_color_annot,
                        fontsize=tmark_fontsize, zorder=10)

                infos = []
                infos.append('.'.join(x for x in target.codes if x))
                dist = source.distance_to(target)
                azi = source.azibazi_to(target)[0]
                infos.append(str_dist(dist))
                infos.append('%.0f\u00B0' % azi)
                axes2.annotate(
                    '\n'.join(infos),
                    xy=(0., 1.),
                    xycoords='axes fraction',
                    xytext=(1., 1.),
                    textcoords='offset points',
                    ha='left',
                    va='top',
                    fontsize=fontsize,
                    fontstyle='normal', zorder=10)

                # annotate axis amplitude
                axes.annotate(
                    '%0.3g %s -' % (-absmax, str_unit(target.quantity)),
                    xycoords='data',
                    xy=(tmarks[1], -absmax),
                    xytext=(1., 1.),
                    textcoords='offset points',
                    ha='right',
                    va='center',
                    fontsize=fontsize - 3,
                    color=obs_color,
                    fontstyle='normal')

                axes2.set_zorder(10)

        for (iyy, ixx), fig in figures.items():
            title = '.'.join(x for x in cg if x)
            if len(figures) > 1:
                title += ' (%i/%i, %i/%i)' % (iyy + 1, nyy, ixx + 1, nxx)

            fig.suptitle(title, fontsize=fontsize_title)

    return figs
def draw_seismic_fits(problem, po):
    """
    Plot waveform fits of the seismic composite and display or save them.

    Parameters
    ----------
    problem : problem object with a 'seismic' composite
    po : :class:`PlotOptions`

    Raises
    ------
    TypeError
        if no seismic composite is defined for the problem
    """
    if 'seismic' not in list(problem.composites.keys()):
        raise TypeError('No seismic composite defined for this problem!')

    logger.info('Drawing Waveform fits ...')

    stage = Stage(homepath=problem.outfolder,
                  backend=problem.config.sampler_config.backend)

    mode = problem.config.problem_config.mode

    if not po.reference:
        llk_str = po.post_llk
        stage.load_results(
            varnames=problem.varnames,
            model=problem.model, stage_number=po.load_stage,
            load='trace', chains=[-1])
    else:
        llk_str = 'ref'

    # outpath is only a file prefix; the extension / figure index is
    # appended when saving below
    outpath = os.path.join(
        problem.config.project_dir,
        mode, po.figure_dir, 'waveforms_%s_%s_%i' % (
            stage.number, llk_str, po.nensemble))

    # check the first file that saving would actually write, otherwise the
    # "plots exist" branch is never reached
    if po.outformat == 'pdf':
        first_outfile = outpath + '.pdf'
    else:
        first_outfile = outpath + '_%i.%s' % (0, po.outformat)

    if po.outformat == 'display' or po.force or \
            not os.path.exists(first_outfile):
        figs = seismic_fits(problem, stage, po)
    else:
        logger.info('waveform plots exist. Use force=True for replotting!')
        return

    if po.outformat == 'display':
        plt.show()
    else:
        logger.info('saving figures to %s' % outpath)
        if po.outformat == 'pdf':
            with PdfPages(outpath + '.pdf') as opdf:
                for fig in figs:
                    opdf.savefig(fig)
        else:
            for i, fig in enumerate(figs):
                fig.savefig(outpath + '_%i.%s' % (i, po.outformat), dpi=po.dpi)
def point2array(point, varnames, rpoint=None):
    """
    Concatenate values of point according to order of given varnames.

    Parameters
    ----------
    point : dict or None
        mapping of variable names to values; each value must hold exactly
        one element
    varnames : list of str
        order of the variables in the output array
    rpoint : dict, optional
        fallback point used for variables missing in ``point``
        (e.g. fixed variables)

    Returns
    -------
    :class:`numpy.ndarray` of float64 with ``len(varnames)`` entries,
    or None if ``point`` is None

    Raises
    ------
    ValueError
        if a variable is missing in ``point`` and no (non-empty) ``rpoint``
        is given, or if a value holds more than one element
    """
    if point is None:
        return None

    array = num.empty(len(varnames), dtype='float64')
    for i, varname in enumerate(varnames):
        try:
            value = point[varname]
        except KeyError:  # in case fixed variable
            if not rpoint:
                raise ValueError(
                    'Fixed Component "%s" no fixed value given!' % varname
                ) from None
            value = rpoint[varname]

        # extract the single element explicitly: assigning a size-1 array
        # with ndim > 0 into a scalar slot is deprecated since NumPy 1.25
        array[i] = num.asarray(value).item()

    return array
def extract_mt_components(problem, po, include_magnitude=False):
    """
    Extract Moment Tensor components from problem results for plotting.

    Parameters
    ----------
    problem : problem object
        ``problem.config.problem_config.source_type`` must be one of
        'MTSource', 'MTQTSource' or 'DCSource'
    po : :class:`PlotOptions`
        if ``po.reference`` is set, only that point is returned; otherwise
        the trace of stage ``po.load_stage`` is loaded
    include_magnitude : bool
        append 'magnitude' as last component

    Returns
    -------
    m6s : :class:`numpy.ndarray`
        2d array, one row per (ensemble) sample, one column per component
    best_mt : :class:`numpy.ndarray` or None
        components of the ``po.post_llk`` point; None for reference plots
    llk_str : str
        ``po.post_llk`` or 'ref'

    Raises
    ------
    ValueError
        for unsupported source types
    """
    source_type = problem.config.problem_config.source_type
    if source_type in ['MTSource', 'MTQTSource']:
        varnames = ['mnn', 'mee', 'mdd', 'mne', 'mnd', 'med']
    elif source_type == 'DCSource':
        varnames = ['strike', 'dip', 'rake']
    else:
        raise ValueError(
            'Plot is only supported for point "MTSource" and "DCSource"')

    if include_magnitude:
        varnames += ['magnitude']

    if not po.reference:
        rpoint = None
        llk_str = po.post_llk
        stage = load_stage(
            problem, stage_number=po.load_stage, load='trace', chains=[-1])

        n_mts = len(stage.mtrace)
        m6s = num.empty((n_mts, len(varnames)), dtype='float64')
        for i, varname in enumerate(varnames):
            try:
                m6s[:, i] = stage.mtrace.get_values(
                    varname, combine=True, squeeze=True).ravel()

            except ValueError:  # if fixed value add that to the ensemble
                # fixed values are identical in every random point, so one
                # point is enough for all fixed components
                if rpoint is None:
                    rpoint = problem.get_random_point()
                m6s[:, i] = num.full(n_mts, rpoint[varname], dtype='float64')

        if po.nensemble:
            logger.info(
                'Drawing %i solutions from ensemble ...' % po.nensemble)
            csteps = float(n_mts) / po.nensemble
            idxs = num.floor(
                num.arange(0, n_mts, csteps)).astype('int32')
            m6s = m6s[idxs, :]
        else:
            logger.info('Drawing full ensemble ...')

        point = get_result_point(stage.mtrace, po.post_llk)
        best_mt = point2array(point, varnames=varnames, rpoint=rpoint)
    else:
        llk_str = 'ref'
        # 2d like the ensemble case, so callers can use e.g. .mean(axis=0)
        m6s = num.atleast_2d(
            point2array(point=po.reference, varnames=varnames))
        best_mt = None

    return m6s, best_mt, llk_str
def draw_fuzzy_beachball(problem, po):
    """
    Plot a fuzzy beachball of the moment tensor ensemble (best solution
    outlined in red) and display or save it.

    Parameters
    ----------
    problem : problem object with a single point source
    po : :class:`PlotOptions`
        ``po.load_stage`` is set to -1 if it is None

    Raises
    ------
    NotImplementedError
        for problems with more than one source
    """
    if problem.config.problem_config.n_sources > 1:
        raise NotImplementedError(
            'Fuzzy beachball is not yet implemented for more than one source!')

    if po.load_stage is None:
        po.load_stage = -1

    m6s, best_mt, llk_str = extract_mt_components(problem, po)

    logger.info('Drawing Fuzzy Beachball ...')

    kwargs = {
        'beachball_type': 'full',
        'size': 8,
        'size_units': 'data',
        'position': (5, 5),
        'color_t': 'black',
        'edgecolor': 'black',
        'grid_resolution': 400}

    outpath = os.path.join(
        problem.outfolder,
        po.figure_dir,
        'fuzzy_beachball_%i_%s_%i.%s' % (
            po.load_stage, llk_str, po.nensemble, po.outformat))

    if not os.path.exists(outpath) or po.force or po.outformat == 'display':
        # create the figure only when plotting, so no empty figure is left
        # behind if the plot already exists
        fig = plt.figure(figsize=(4., 4.))
        fig.subplots_adjust(left=0., right=1., bottom=0., top=1.)
        axes = fig.add_subplot(1, 1, 1)

        beachball.plot_fuzzy_beachball_mpl_pixmap(
            m6s, axes, best_mt=best_mt, best_color='red', **kwargs)

        axes.set_xlim(0., 10.)
        axes.set_ylim(0., 10.)
        axes.set_axis_off()

        if not po.outformat == 'display':
            logger.info('saving figure to %s' % outpath)
            fig.savefig(outpath, dpi=po.dpi)
        else:
            plt.show()

    else:
        logger.info('Plot already exists! Please use --force to overwrite!')
def fuzzy_mt_decomposition(
        axes, list_m6s,
        labels=None, colors=None, fontsize=12):
    """
    Plot fuzzy moment tensor decompositions for list of mt ensembles.

    Each ensemble is drawn as one row with fuzzy beachballs of the full,
    isotropic, deviatoric, CLVD and DC parts and their percentage ranges
    (2.5 - 97.5 % quantiles of the ratios).

    Parameters
    ----------
    axes : :class:`matplotlib.axes.Axes`
    list_m6s : list of :class:`numpy.ndarray`
        2d arrays (n_samples, n_components); the last column is used as
        magnitude, so it must be included (see
        ``extract_mt_components(..., include_magnitude=True)``)
    labels : list of str, optional
        row labels; defaults to 'Ensemble' for the first row only
    colors : list, optional
        beachball colors per row; None entries use ``mpl_graph_color``
    fontsize : int
    """
    logger.info('Drawing Fuzzy MT Decomposition ...')

    # beachball kwargs
    kwargs = {
        'beachball_type': 'full',
        'size': 1.,
        'size_units': 'data',
        'edgecolor': 'black',
        'linewidth': 1,
        'grid_resolution': 200}

    def get_decomps(source_vals):
        isos = []
        dcs = []
        clvds = []
        devs = []
        tots = []
        for val in source_vals:
            m = mt.MomentTensor.from_values(val)
            iso, dc, clvd, dev, tot = m.standard_decomposition()
            isos.append(iso)
            dcs.append(dc)
            clvds.append(clvd)
            devs.append(dev)
            tots.append(tot)
        return isos, dcs, clvds, devs, tots

    yscale = 1.3
    nlines = len(list_m6s)
    nlines_max = nlines * yscale

    if colors is None:
        colors = nlines * [None]

    if labels is None:
        labels = ['Ensemble'] + ([None] * (nlines - 1))

    lines = []
    for i, (label, m6s, color) in enumerate(zip(labels, list_m6s, colors)):
        if color is None:
            color = mpl_graph_color(i)

        lines.append(
            (label, m6s, color))

    # beachball sizes are scaled relative to the largest mean magnitude
    magnitude_full_max = max(m6s.mean(axis=0)[-1] for (_, m6s, _) in lines)

    for xpos, label in [
            (0., 'Full'),
            (2., 'Isotropic'),
            (4., 'Deviatoric'),
            (6., 'CLVD'),
            (8., 'DC')]:

        axes.annotate(
            label,
            xy=(1 + xpos, nlines_max),
            xycoords='data',
            xytext=(0., 0.),
            textcoords='offset points',
            ha='center',
            va='center',
            color='black',
            fontsize=fontsize)

    for i, (row_label, m6s, color_t) in enumerate(lines):
        ypos = nlines_max - (i * yscale) - 1.0
        mean_magnitude = m6s.mean(0)[-1]
        size0 = mean_magnitude / magnitude_full_max

        isos, dcs, clvds, devs, tots = get_decomps(m6s)
        axes.annotate(
            row_label,
            xy=(-2., ypos),
            xycoords='data',
            xytext=(0., 0.),
            textcoords='offset points',
            ha='left',
            va='center',
            color='black',
            fontsize=fontsize)

        for xpos, decomp, ops in [
                (0., tots, '-'),
                (2., isos, '='),
                (4., devs, '='),
                (6., clvds, '+'),
                (8., dcs, None)]:

            ratios = num.array([comp[1] for comp in decomp])
            ratio = ratios.mean()
            ratios_diff = ratios.max() - ratios.min()

            ratios_qu = quantiles(ratios * 100.)
            mt_parts = [comp[2] for comp in decomp]

            if ratio > 1e-4:
                try:
                    size = math.sqrt(ratio) * 0.95 * size0
                    kwargs['position'] = (1. + xpos, ypos)
                    kwargs['size'] = size
                    kwargs['color_t'] = color_t
                    beachball.plot_fuzzy_beachball_mpl_pixmap(
                        mt_parts, axes, best_mt=None, **kwargs)

                    if ratios_diff > 0.:
                        ratio_label = '{:03.1f}-{:03.1f}%'.format(
                            ratios_qu[2.5], ratios_qu[97.5])
                    else:
                        ratio_label = '{:03.1f}%'.format(ratios_qu[2.5])

                    axes.annotate(
                        ratio_label,
                        xy=(1. + xpos, ypos - 0.65),
                        xycoords='data',
                        xytext=(0., 0.),
                        textcoords='offset points',
                        ha='center',
                        va='center',
                        color='black',
                        fontsize=fontsize - 2)

                except beachball.BeachballError as e:
                    logger.warning(str(e))

                    axes.annotate(
                        'ERROR',
                        xy=(1. + xpos, ypos),
                        ha='center',
                        va='center',
                        color='red',
                        fontsize=fontsize)

            else:
                axes.annotate(
                    'N/A',
                    xy=(1. + xpos, ypos),
                    ha='center',
                    va='center',
                    color='black',
                    fontsize=fontsize)
                ratio_label = '{:03.1f}%'.format(0.)
                axes.annotate(
                    ratio_label,
                    xy=(1. + xpos, ypos - 0.65),
                    xycoords='data',
                    xytext=(0., 0.),
                    textcoords='offset points',
                    ha='center',
                    va='center',
                    color='black',
                    fontsize=fontsize - 2)

            # operator sign between this and the next decomposition part
            if ops is not None:
                axes.annotate(
                    ops,
                    xy=(2. + xpos, ypos),
                    ha='center',
                    va='center',
                    color='black',
                    fontsize=fontsize)

    axes.axison = False
    axes.set_xlim(-2.25, 9.75)
    axes.set_ylim(-0.7, nlines_max + 0.5)
    axes.set_axis_off()
def draw_fuzzy_mt_decomposition(problem, po):
    """
    Plot the fuzzy moment tensor decomposition of the result ensemble and
    display or save it.

    Parameters
    ----------
    problem : problem object with a single point source
    po : :class:`PlotOptions`
        ``po.load_stage`` is set to -1 if it is None

    Raises
    ------
    NotImplementedError
        for problems with more than one source
    """
    fontsize = 10

    if problem.config.problem_config.n_sources > 1:
        raise NotImplementedError(
            'Fuzzy MT decomposition is not yet '
            'implemented for more than one source!')

    if po.load_stage is None:
        po.load_stage = -1

    # magnitude is required as last column for the relative beachball sizes
    m6s, _, llk_str = extract_mt_components(
        problem, po, include_magnitude=True)

    outpath = os.path.join(
        problem.outfolder,
        po.figure_dir,
        'fuzzy_mt_decomposition_%i_%s_%i.%s' % (
            po.load_stage, llk_str, po.nensemble, po.outformat))

    if not os.path.exists(outpath) or po.force or po.outformat == 'display':

        fig = plt.figure(figsize=(6., 2.))
        fig.subplots_adjust(left=0., right=1., bottom=0., top=1.)
        axes = fig.add_subplot(1, 1, 1)

        fuzzy_mt_decomposition(axes, list_m6s=[m6s], fontsize=fontsize)

        if not po.outformat == 'display':
            logger.info('saving figure to %s' % outpath)
            fig.savefig(outpath, dpi=po.dpi)
        else:
            plt.show()

    else:
        logger.info('Plot already exists! Please use --force to overwrite!')
def draw_hudson(problem, po):
    """
    Modified from grond. Plot the hudson graph for the reference event (grey)
    and the best solution (red beachball).
    Also a random subset (about 5 %) of the models from the selected stage
    is plotted as small beachballs on the hudson graph; all other models
    are drawn as small red dots.

    Parameters
    ----------
    problem : problem object
        Must contain exactly one source.
    po : plot options
        ``po.load_stage`` is set to -1 if it is None. The reference event is
        only drawn if ``po.reference`` is empty.

    Raises
    ------
    NotImplementedError
        If the problem has more than one source.
    """
    from pyrocko.plot import beachball, hudson
    from pyrocko import moment_tensor as mtm
    from numpy import random

    if problem.config.problem_config.n_sources > 1:
        raise NotImplementedError(
            'Hudson plot is not yet implemented for more than one source!')

    if po.load_stage is None:
        po.load_stage = -1

    m6s, best_mt, llk_str = extract_mt_components(problem, po)

    outpath = os.path.join(
        problem.outfolder,
        po.figure_dir,
        'hudson_%i_%s_%i.%s' % (
            po.load_stage, llk_str, po.nensemble, po.outformat))

    # check before drawing: avoids rendering (and leaking) a figure that
    # would not be stored anyway
    if os.path.exists(outpath) and not po.force and \
            po.outformat != 'display':
        logger.info('Plot already exists! Please use --force to overwrite!')
        return

    logger.info('Drawing Hudson plot ...')

    fontsize = 12
    beachball_type = 'full'
    color = 'red'
    markersize = fontsize * 1.5
    markersize_small = markersize * 0.2
    beachballsize = markersize
    beachballsize_small = beachballsize * 0.5

    fig = plt.figure(figsize=(4., 4.))
    fig.subplots_adjust(left=0., right=1., bottom=0., top=1.)
    axes = fig.add_subplot(1, 1, 1)
    hudson.draw_axes(axes)

    data = []
    for m6 in m6s:
        mt = mtm.as_mt(m6)
        u, v = hudson.project(mt)

        # draw only a random subset as beachballs, the rest as dots
        if random.random() < 0.05:
            try:
                beachball.plot_beachball_mpl(
                    mt, axes,
                    beachball_type=beachball_type,
                    position=(u, v),
                    size=beachballsize_small,
                    color_t='black',
                    alpha=0.5,
                    zorder=1,
                    linewidth=0.25)
            except beachball.BeachballError as e:
                logger.warning(str(e))
        else:
            data.append((u, v))

    if data:
        u, v = num.array(data).T
        axes.plot(
            u, v, 'o',
            color=color,
            ms=markersize_small,
            mec='none',
            mew=0,
            alpha=0.25,
            zorder=0)

    if best_mt is not None:
        mt = mtm.as_mt(best_mt)
        u, v = hudson.project(mt)

        try:
            beachball.plot_beachball_mpl(
                mt, axes,
                beachball_type=beachball_type,
                position=(u, v),
                size=beachballsize,
                color_t=color,
                alpha=0.5,
                zorder=2,
                linewidth=0.25)
        except beachball.BeachballError as e:
            logger.warning(str(e))

    if isinstance(problem.event.moment_tensor, mtm.MomentTensor):
        mt = problem.event.moment_tensor
        u, v = hudson.project(mt)

        if not po.reference:
            try:
                beachball.plot_beachball_mpl(
                    mt, axes,
                    beachball_type=beachball_type,
                    position=(u, v),
                    size=beachballsize,
                    color_t='grey',
                    alpha=0.5,
                    zorder=2,
                    linewidth=0.25)
                logger.info('drawing reference event in grey ...')
            except beachball.BeachballError as e:
                logger.warning(str(e))
    else:
        logger.info(
            'No reference event moment tensor information given, '
            'skipping drawing ...')

    if po.outformat != 'display':
        logger.info('saving figure to %s' % outpath)
        fig.savefig(outpath, dpi=po.dpi)
        plt.close(fig)
    else:
        plt.show()
def histplot_op(
        ax, data, reference=None, alpha=.35, color=None, cmap=None, bins=None,
        tstd=None, qlist=[0.01, 99.99], cbounds=None, kwargs={}):
    """
    Modified from pymc3. Additional color argument.

    Plot a normalised histogram for every column of ``data`` into ``ax``.
    The x-limits are widened so that the range between the first and last
    requested quantile (extended to include ``reference``) plus a margin of
    ``tstd`` on both sides is visible.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
    data : 2-d array_like
        Samples; one histogram is drawn per column.
    reference : float, optional
        Value that is always included in the displayed range.
    alpha : float
        Transparency of the histogram.
    color : optional
        Face- and edgecolor of the histogram.
    cmap : :class:`matplotlib.colors.Colormap`, optional
        If given, the bars are coloured according to their bin centres.
    bins : int or sequence, optional
        Histogram bins; 40 bins are used if None.
    tstd : float, optional
        Margin added on both sides of the x-range; defaults to the standard
        deviation of the respective column.
    qlist : list
        Quantiles in percent; the first and last entry bound the range.
    cbounds : tuple, optional
        (min, max) used to normalise the bin centres for ``cmap``.
    kwargs : dict
        Further keyword arguments for :meth:`matplotlib.axes.Axes.hist`.
        The dictionary is not modified.

    Raises
    ------
    TypeError
        If ``cmap`` is not a matplotlib colormap.
    """
    if color is not None and cmap is not None:
        logger.debug('Using color for histogram edgecolor ...')

    if cmap is not None:
        from matplotlib.colors import Colormap
        if not isinstance(cmap, Colormap):
            raise TypeError(
                'The colormap needs to be a valid matplotlib colormap!')

        histtype = 'bar'
    else:
        histtype = 'stepfilled'

    # work on a copy: writing the normalisation flag into the caller's dict
    # (or the shared default dict) would leak it into later calls
    hist_kwargs = dict(kwargs) if kwargs else {}
    major, _ = get_matplotlib_version()
    if major < 3:
        hist_kwargs['normed'] = True
    else:
        hist_kwargs['density'] = True

    for i in range(data.shape[1]):
        d = data[:, i]
        quants = quantiles(d, qlist=qlist)

        mind = quants[qlist[0]]
        maxd = quants[qlist[-1]]

        if reference is not None:
            mind = num.minimum(mind, reference)
            maxd = num.maximum(maxd, reference)

        # per-column defaults; the first column's values must not be reused
        # for the following columns
        col_tstd = num.std(d) if tstd is None else tstd
        # ceil((maxd - mind) / ((maxd - mind) / 40.)) is 40, but divides by
        # zero for constant data
        col_bins = 40 if bins is None else bins

        _, outbins, patches = ax.hist(
            d, bins=col_bins, stacked=True, alpha=alpha,
            align='left', histtype=histtype, color=color, edgecolor=color,
            **hist_kwargs)

        if cmap is not None:
            bin_centers = 0.5 * (outbins[:-1] + outbins[1:])

            if cbounds is None:
                frac = bin_centers - min(bin_centers)
                frac /= max(frac)
            else:
                frac = (bin_centers - cbounds[0]) / (cbounds[1] - cbounds[0])

            for c, p in zip(frac, patches):
                plt.setp(p, 'facecolor', cmap(c))

        left, right = ax.get_xlim()
        leftb = mind - col_tstd
        rightb = maxd + col_tstd

        # presumably guards against matplotlib's default (0, 1) limits of an
        # otherwise empty axes; other limits are kept visible
        if left != 0.0 or right != 1.0:
            leftb = num.minimum(leftb, left)
            rightb = num.maximum(rightb, right)

        ax.set_xlim(leftb, rightb)
def unify_tick_intervals(axs, varnames, ntickmarks_max=5, axis='x'):
    """
    Take figure axes objects and determine unit ranges between common
    unit classes (see utility.unit_sets). Assures that the number of
    increments is not larger than ntickmarks_max. Will thus overwrite the
    minimum range of a unit class with max_range / ntickmarks_max if the
    minimum range is smaller than that.

    Parameters
    ----------
    axs : :class:`numpy.ndarray` of matplotlib axes
        Traversed in column-major (Fortran) order alongside ``varnames``.
    varnames : list of str
        Names of the variables plotted on the axes.
    ntickmarks_max : int
        Maximum number of tick increments.
    axis : str
        'x' or 'y', the axis whose limits are evaluated.

    Returns
    -------
    dict : with utility.unit_sets keys and [min_range, max_range] as values

    Raises
    ------
    ValueError
        If ``axis`` is neither 'x' nor 'y'.
    """
    if axis == 'x':
        def get_limits(ax):
            return ax.get_xlim()
    elif axis == 'y':
        def get_limits(ax):
            return ax.get_ylim()
    else:
        raise ValueError('Only "x" or "y" allowed!')

    unities = {}
    for setname in utility.unit_sets.keys():
        unities[setname] = [num.inf, -num.inf]

    for ax, varname in zip(axs.ravel('F'), varnames):
        lower, upper = get_limits(ax)
        # plain scalar; num.diff would return a 1-element array that then
        # propagates into the ranges and the tick increments
        varrange = upper - lower
        for setname, (min_range, max_range) in unities.items():
            if varname in utility.unit_sets[setname]:
                unities[setname] = [
                    min(min_range, varrange), max(max_range, varrange)]

    for setname, (min_range, max_range) in unities.items():
        max_range_frac = max_range / ntickmarks_max
        if max_range_frac > min_range:
            logger.debug(
                'Range difference between min and max for %s is large!'
                ' Extending min_range to %f' % (setname, max_range_frac))
            unities[setname] = [max_range_frac, max_range]

    return unities
def apply_unified_axis(axs, varnames, unities, axis='x', ntickmarks_max=3,
                       scale_factor=2 / 3):
    """
    Apply common tick increments to axes of variables of the same unit class.

    For variables in ``utility.grouped_vars`` the increment is derived from
    the minimum range of their unit class (see :func:`unify_tick_intervals`)
    times ``scale_factor``. The axis limits are snapped to that increment
    and truncated to the physical bounds of the variable. All other axes
    get a :class:`matplotlib.ticker.MaxNLocator` with 3 bins.

    Parameters
    ----------
    axs : :class:`numpy.ndarray` of matplotlib axes
        Traversed in column-major (Fortran) order alongside ``varnames``.
    varnames : list of str
    unities : dict
        Output of :func:`unify_tick_intervals`.
    axis : str
        'x' or 'y'
    ntickmarks_max : int
        Approximate number of ticks for the AutoScaler.
    scale_factor : float
        Factor applied to the minimum range before rounding it to a nice
        increment.

    Raises
    ------
    ValueError
        If ``axis`` is neither 'x' nor 'y'.
    """
    if axis not in ('x', 'y'):
        raise ValueError('Only "x" or "y" allowed!')

    for ax, v in zip(axs.ravel('F'), varnames):
        if axis == 'x':
            get_lim, set_lim, target_axis = ax.get_xlim, ax.set_xlim, ax.xaxis
        else:
            get_lim, set_lim, target_axis = ax.get_ylim, ax.set_ylim, ax.yaxis

        if v in utility.grouped_vars:
            for setname, varrange in unities.items():
                if v in utility.unit_sets[setname]:
                    inc = nice_value(varrange[0] * scale_factor)
                    autos = AutoScaler(
                        inc=inc, snap='on', approx_ticks=ntickmarks_max)

                    vmin, vmax, _ = autos.make_scale(
                        get_lim(), override_mode='min-max')

                    # check physical bounds if passed truncate
                    phys_min, phys_max = physical_bounds[v]
                    if vmin < phys_min:
                        vmin = phys_min
                    if vmax > phys_max:
                        vmax = phys_max

                    set_lim((vmin, vmax))

                    ticks = num.arange(vmin, vmax + inc, inc).tolist()
                    target_axis.set_ticks(ticks)
        else:
            target_axis.set_major_locator(tick.MaxNLocator(nbins=3))
def traceplot(trace, varnames=None, transform=lambda x: x, figsize=None,
              lines={}, chains=None, combined=False, grid=False,
              varbins=None, nbins=40, color=None, source_idxs=None,
              alpha=0.35, priors=None, prior_alpha=1, prior_style='--',
              axs=None, posterior=None, fig=None, plot_style='kde',
              prior_bounds={}, unify=True, qlist=[0.1, 99.9], kwargs={}):
    """
    Plots posterior pdfs as histograms from multiple mtrace objects.
    Modified from pymc3.

    Parameters
    ----------
    trace : result of MCMC run
    varnames : list of variable names
        Variables to be plotted, if None all variables not ending with an
        underscore are plotted. 'geo_like' and 'seis_like' are never
        plotted; the passed list is not modified.
    transform : callable
        Function to transform data (defaults to identity)
    figsize : figure size tuple
        If None, an A6 landscape (< 5 variables), A5 portrait (< 7) or
        A4 portrait paper size is used
    lines : dict
        Dictionary of variable name / value to be overplotted as vertical
        lines to the posteriors and horizontal lines on sample values
        e.g. mean of posteriors, true values of a simulation
    chains : int or list of ints
        chain indexes to select from the trace
    combined : bool
        Flag for combining multiple chains into a single chain. If False
        (default), chains will be plotted separately.
    grid : bool
        Flag for adding gridlines to histogram. Defaults to False.
    varbins : list of arrays
        List containing the binning arrays for the variables, if None they will
        be created.
    nbins : int
        Number of bins for each histogram
    color : tuple
        mpl color tuple
    source_idxs : list
        array like, indexes to sources to plot marginals
    alpha : float
        Alpha value for plot line. Defaults to 0.35.
    priors, prior_alpha, prior_style
        Currently unused.
    axs : axes
        Matplotlib axes array of shape (ceil(n / 2), 2). Defaults to None.
    posterior : str
        To mark posterior value in distribution 'max', 'min', 'mean', 'all';
        None or 'None' disables the marking
    fig : figure
        Matplotlib figure. Defaults to None; if ``axs`` are given, their
        figure is used then.
    plot_style : str
        'kde' or 'hist'
    prior_bounds : dict
        variable name / parameter with ``lower`` and ``upper`` attributes
        shown in the axis labels
    unify : bool
        If true axis units that belong to one group e.g. [km] will
        have common axis increments
    qlist : list
        of quantiles (in percent) bounding the created bins.
        Default: [0.1, 99.9]
    kwargs : dict
        for histplot op; the keys 'ntickmarks_max', 'scale_factor' and
        'lines_color' are consumed here. The dict is not modified.

    Returns
    -------
    fig : matplotlib figure
    axs : array of matplotlib axes
    varbins : list of arrays
    """
    fontsize = 10

    # work on a copy: popping from the caller's dict would silently drop
    # these options on later calls with the same dict
    kwargs = dict(kwargs) if kwargs else {}
    ntickmarks_max = kwargs.pop('ntickmarks_max', 3)
    scale_factor = kwargs.pop('scale_factor', 2 / 3)
    lines_color = kwargs.pop('lines_color', 'k')

    # NOTE: changes the numpy print options globally; shortens the prior
    # bounds shown in the axis labels
    num.set_printoptions(precision=3)

    def make_bins(data, nbins=40, qlist=None):
        d = data.flatten()
        if qlist is not None:
            qu = quantiles(d, qlist=qlist)
            mind = qu[qlist[0]]
            maxd = qu[qlist[-1]]
        else:
            mind = d.min()
            maxd = d.max()
        return num.linspace(mind, maxd, nbins)

    if varnames is None:
        varnames = [name for name in trace.varnames if not name.endswith('_')]

    # build a new list so that the caller's list is left untouched
    varnames = [
        name for name in varnames if name not in ('geo_like', 'seis_like')]

    # the default None used to fall through to posterior_idxs[None]
    mark_posterior = posterior not in (None, 'None')
    if mark_posterior:
        llk = trace.get_values(
            'like', combine=combined, chains=chains, squeeze=False)
        llk = num.squeeze(transform(llk[0]))
        llk = pmp.utils.make_2d(llk)

        posterior_idxs = utility.get_fit_indexes(llk)

        colors = {
            'mean': scolor('orange1'),
            'min': scolor('butter1'),
            'max': scolor('scarletred2')}

    n = len(varnames)
    nrow = int(num.ceil(n / 2.))
    ncol = 2

    n_fig = nrow * ncol
    if figsize is None:
        if n < 5:
            figsize = mpl_papersize('a6', 'landscape')
        elif n < 7:
            figsize = mpl_papersize('a5', 'portrait')
        else:
            figsize = mpl_papersize('a4', 'portrait')

    if axs is None:
        fig, axs = plt.subplots(nrow, ncol, figsize=figsize)
        axs = num.atleast_2d(axs)
    elif axs.shape != (nrow, ncol):
        raise TypeError('traceplot requires n*2 subplots %i, %i' % (
            nrow, ncol))

    if fig is None:
        # axes were passed without their figure
        fig = axs[0, 0].get_figure()

    if varbins is None:
        make_bins_flag = True
        varbins = []
    else:
        make_bins_flag = False

    input_color = copy.deepcopy(color)
    for i in range(n_fig):
        coli, rowi = utility.mod_i(i, nrow)

        if i > len(varnames) - 1:
            # remove surplus subplots
            try:
                fig.delaxes(axs[rowi, coli])
            except KeyError:
                pass
            continue

        v = varnames[i]
        ax = axs[rowi, coli]
        color = copy.deepcopy(input_color)

        for d in trace.get_values(
                v, combine=combined, chains=chains, squeeze=False):
            d = transform(d)
            # iterate over columns in case varsize > 1

            if v in dist_vars:
                if source_idxs is None:
                    logger.info('No patches defined using 1 every 10!')
                    source_idxs = num.arange(0, d.shape[1], 10).tolist()

                logger.info(
                    'Plotting patches: %s' % utility.list2string(
                        source_idxs))

                try:
                    selected = d.T[source_idxs]
                except IndexError:
                    raise IndexError(
                        'One or several patches do not exist! '
                        'Patch idxs: %s' % utility.list2string(
                            source_idxs))
            else:
                selected = d.T

            for isource, e in enumerate(selected):
                e = pmp.utils.make_2d(e)
                if make_bins_flag:
                    varbin = make_bins(e, nbins=nbins, qlist=qlist)
                    varbins.append(varbin)
                else:
                    varbin = varbins[i]

                if lines and v in lines:
                    reference = lines[v]
                else:
                    reference = None

                if color is None:
                    pcolor = mpl_graph_color(isource)
                else:
                    pcolor = color

                if plot_style == 'kde':
                    pmp.kdeplot(
                        e, shade=alpha, ax=ax,
                        color=color, linewidth=1.,
                        kwargs_shade={'color': pcolor})
                    ax.relim()
                    ax.autoscale(tight=False)
                    ax.set_ylim(0)
                    xax = ax.get_xaxis()
                    xticker = tick.MaxNLocator(nbins=5)
                    xax.set_major_locator(xticker)
                elif plot_style == 'hist':
                    histplot_op(
                        ax, e, reference=reference,
                        bins=varbin, alpha=alpha, color=pcolor,
                        qlist=qlist, kwargs=kwargs)
                else:
                    raise NotImplementedError(
                        'Plot style "%s" not implemented' % plot_style)

                try:
                    param = prior_bounds[v]

                    if v in dist_vars:
                        try:  # variable bounds
                            lower = param.lower[source_idxs]
                            upper = param.upper[source_idxs]
                        except IndexError:
                            lower, upper = param.lower, param.upper

                        title = '{} {}'.format(v, plot_units[hypername(v)])
                    else:
                        lower = num.array2string(
                            param.lower, separator=',')[1:-1]
                        upper = num.array2string(
                            param.upper, separator=',')[1:-1]

                        title = '{} {} priors: ({}; {})'.format(
                            v, plot_units[hypername(v)], lower, upper)
                except KeyError:
                    try:
                        title = '{} {}'.format(v, float(lines[v]))
                    except (KeyError, TypeError):
                        # no (scalar) reference value, e.g. lines is None
                        title = '{} {}'.format(v, plot_units[hypername(v)])

                ax.set_xlabel(title, fontsize=fontsize)
                ax.grid(grid)
                ax.get_yaxis().set_visible(False)
                format_axes(ax)
                ax.tick_params(axis='x', labelsize=fontsize)

                if lines:
                    try:
                        ax.axvline(x=lines[v], color=lines_color, lw=1.)
                    except KeyError:
                        pass

                if mark_posterior:
                    if posterior == 'all':
                        for k, idx in posterior_idxs.items():
                            ax.axvline(x=e[idx], color=colors[k], lw=1.)
                    else:
                        idx = posterior_idxs[posterior]
                        ax.axvline(x=e[idx], color=pcolor, lw=1.)

    if unify:
        unities = unify_tick_intervals(
            axs, varnames, ntickmarks_max=ntickmarks_max, axis='x')
        apply_unified_axis(axs, varnames, unities, axis='x',
                           scale_factor=scale_factor)

    if source_idxs:
        axs[0, 0].legend(source_idxs)

    fig.tight_layout()
    return fig, axs, varbins
def get_matplotlib_version(version=None):
    """
    Return the major and minor version of matplotlib as floats.

    Parameters
    ----------
    version : str, optional
        Version string to parse; defaults to the version of the installed
        matplotlib.

    Returns
    -------
    tuple of float
        (major, minor), where minor carries the patch level as its decimal
        part, e.g. '3.7.1' -> (3.0, 7.1). Suffixes such as 'rc1' or
        '.post1' are ignored.

    Raises
    ------
    ValueError
        If the version string does not start with a number.
    """
    import re

    if version is None:
        from matplotlib import __version__ as version

    # parse numerically: slicing single characters broke for multi-digit
    # major versions and for pre-/post-release suffixes
    match = re.match(r'(\d+)(?:\.(\d+))?(?:\.(\d+))?', version)
    if match is None:
        raise ValueError(
            'Could not parse matplotlib version "%s"' % version)

    major, minor, patch = match.groups()
    minor = minor or '0'
    if patch is not None:
        minor = '%s.%s' % (minor, patch)

    return float(major), float(minor)
def draw_posteriors(problem, plot_options):
    """
    Identify which stage is the last complete stage and plot posteriors.

    One figure of marginal posterior histograms is created for every stage
    index returned for ``plot_options.load_stage``. Figures are saved to the
    figure directory (and closed) unless the output format is 'display',
    in which case all of them are shown at the end. Existing plots are
    skipped unless ``plot_options.force`` is set.
    """
    hypers = utility.check_hyper_flag(problem)
    po = plot_options

    stage = Stage(homepath=problem.outfolder,
                  backend=problem.config.sampler_config.backend)

    pc = problem.config.problem_config

    list_indexes = stage.handler.get_stage_indexes(po.load_stage)

    if hypers:
        varnames = list(problem.hypernames) + ['like']
    else:
        varnames = list(problem.varnames) + list(problem.hypernames) + \
            list(problem.hierarchicalnames) + ['like']

    if len(po.varnames) > 0:
        varnames = po.varnames

    logger.info('Plotting variables: %s' % ', '.join(varnames))

    # stage independent quantities
    if po.source_idxs:
        sidxs = utility.list2string(po.source_idxs, fill='_')
    else:
        sidxs = ''

    prior_bounds = {}
    prior_bounds.update(**pc.hyperparameters)
    prior_bounds.update(**pc.hierarchicals)
    prior_bounds.update(**pc.priors)

    figs = []
    for s in list_indexes:
        outpath = os.path.join(
            problem.outfolder,
            po.figure_dir,
            'stage_%i_%s_%s.%s' % (s, sidxs, po.post_llk, po.outformat))

        if not os.path.exists(outpath) or po.force:
            logger.info('plotting stage: %s' % stage.handler.stage_path(s))
            stage.load_results(
                varnames=problem.varnames,
                model=problem.model, stage_number=s,
                load='trace', chains=[-1])

            fig, _, _ = traceplot(
                stage.mtrace,
                varnames=varnames,
                chains=None,
                combined=True,
                source_idxs=po.source_idxs,
                plot_style='hist',
                lines=po.reference,
                posterior=po.post_llk,
                prior_bounds=prior_bounds)

            if not po.outformat == 'display':
                logger.info('saving figure to %s' % outpath)
                fig.savefig(outpath, format=po.outformat, dpi=po.dpi)
                # do not accumulate open figures over the stages
                plt.close(fig)
            else:
                figs.append(fig)
        else:
            logger.info(
                'plot for stage %s exists. Use force=True for'
                ' replotting!' % s)

    if po.outformat == 'display':
        plt.show()
def draw_correlation_hist(problem, plot_options):
    """
    Draw parameter correlation plot and histograms from the final atmip stage.
    Only feasible for 'geometry' problem.

    The figures are written to ``corr_hist_<stage>_<llk>.pdf`` (pdf output)
    or ``corr_hist_<stage>_<llk>_<i>.<format>``; existing plots are skipped
    unless ``plot_options.force`` is set.

    Raises
    ------
    ModeError
        If the problem is not in geometry mode.
    ValueError
        If ``plot_options.load_stage`` is 0.
    TypeError
        If fewer than two variables are selected.
    """
    po = plot_options
    mode = problem.config.problem_config.mode

    # explicit checks instead of asserts, which are stripped under -O
    if mode != geometry_mode_str:
        raise ModeError(
            'Wrong optimization mode: %s! This plot '
            'variant is only valid for "%s" mode' % (mode, geometry_mode_str))

    if po.load_stage == 0:
        raise ValueError('load_stage must not be 0 for correlation plots!')

    hypers = utility.check_hyper_flag(problem)

    if hypers:
        varnames = problem.hypernames
    else:
        varnames = list(problem.varnames) + problem.hypernames + ['like']

    if len(po.varnames) > 0:
        varnames = po.varnames

    logger.info('Plotting variables: %s' % ', '.join(varnames))

    if len(varnames) < 2:
        raise TypeError('Need at least two parameters to compare! '
                        'Found only %i variables! ' % len(varnames))

    stage = load_stage(
        problem, stage_number=po.load_stage, load='trace', chains=[-1])

    if not po.reference:
        reference = get_result_point(stage.mtrace, po.post_llk)
        llk_str = po.post_llk
    else:
        reference = po.reference
        llk_str = 'ref'

    outpath = os.path.join(
        problem.outfolder, po.figure_dir, 'corr_hist_%s_%s' % (
            stage.number, llk_str))

    # check a file that is actually written below; ``outpath`` itself has
    # no extension and never exists
    if po.outformat == 'pdf':
        first_outfile = outpath + '.pdf'
    else:
        first_outfile = '%s_0.%s' % (outpath, po.outformat)

    if os.path.exists(first_outfile) and not po.force and \
            po.outformat != 'display':
        logger.info('correlation plot exists. Use force=True for replotting!')
        return

    figs, _ = correlation_plot_hist(
        mtrace=stage.mtrace,
        varnames=varnames,
        cmap=plt.cm.gist_earth_r,
        chains=None,
        point=reference,
        point_size=6)

    if po.outformat == 'display':
        plt.show()
    else:
        logger.info('saving figures to %s' % outpath)
        if po.outformat == 'pdf':
            with PdfPages(outpath + '.pdf') as opdf:
                for fig in figs:
                    opdf.savefig(fig)
        else:
            for i, fig in enumerate(figs):
                fig.savefig(
                    '%s_%i.%s' % (outpath, i, po.outformat), dpi=po.dpi)

        for fig in figs:
            plt.close(fig)
def n_model_plot(models, axes=None, draw_bg=True, highlightidx=()):
    """
    Plot cake layered earth models.

    Every model is drawn as a thin, lightened P- and S-velocity profile
    over depth. The models with indexes in ``highlightidx`` are drawn on top
    with thick lines: the lowest highlighted index in 'aluminium5', the
    others in the P- ('scarletred2') and S- ('skyblue2') colours.

    Parameters
    ----------
    models : list
        Earth models providing ``profile('z' | 'vp' | 'vs')``.
    axes : :class:`matplotlib.axes.Axes`, optional
        Axes to draw into; a new A6 portrait figure is created if None.
    draw_bg : bool
        Sketch ``models[0]`` as background; otherwise the top and right
        spines are hidden.
    highlightidx : sequence of int
        Indexes into ``models`` to highlight.

    Returns
    -------
    fig : :class:`matplotlib.figure.Figure`
        The figure containing ``axes``.
    axes : :class:`matplotlib.axes.Axes`
    """
    fontsize = 10
    if axes is None:
        mpl_init(fontsize=fontsize)
        fig, axes = plt.subplots(
            nrows=1, ncols=1, figsize=mpl_papersize('a6', 'portrait'))
        labelpos = mpl_margins(
            fig, left=6, bottom=4, top=1.5, right=0.5, units=fontsize)
        labelpos(axes, 2., 1.5)
    else:
        # ``fig`` used to be undefined when axes were passed in
        fig = axes.get_figure()

    def plot_profile(mod, axes, vp_c, vs_c, lw=0.5):
        z = mod.profile('z')
        vp = mod.profile('vp')
        vs = mod.profile('vs')
        axes.plot(vp, z, color=vp_c, lw=lw)
        axes.plot(vs, z, color=vs_c, lw=lw)

    cp.labelspace(axes)
    cp.labels_model(axes=axes)
    if draw_bg:
        cp.sketch_model(models[0], axes=axes)
    else:
        axes.spines['right'].set_visible(False)
        axes.spines['top'].set_visible(False)

    ref_vp_c = scolor('aluminium5')
    ref_vs_c = scolor('aluminium5')
    vp_c = scolor('scarletred2')
    vs_c = scolor('skyblue2')

    light_vp_c = light(vp_c, 0.3)
    light_vs_c = light(vs_c, 0.3)
    for mod in models:
        plot_profile(mod, axes, vp_c=light_vp_c, vs_c=light_vs_c, lw=1.)

    for count, i in enumerate(sorted(highlightidx)):
        if count == 0:
            vpcolor = ref_vp_c
            vscolor = ref_vs_c
        else:
            vpcolor = vp_c
            vscolor = vs_c

        plot_profile(
            models[i], axes, vp_c=vpcolor, vs_c=vscolor, lw=2.)

    ymin, ymax = axes.get_ylim()
    xmin, xmax = axes.get_xlim()
    xmin = 0.
    my = (ymax - ymin) * 0.05
    mx = (xmax - xmin) * 0.2
    # depth axis pointing downwards with a small margin
    axes.set_ylim(ymax, ymin - my)
    axes.set_xlim(xmin, xmax + mx)
    return fig, axes
def load_earthmodels(store_superdir, store_ids, depth_max='cmb'):
    """
    Load the 1-d earth models from the 'config' files of the given stores.

    Parameters
    ----------
    store_superdir : str
        Directory containing one sub-directory per store id.
    store_ids : list of str
        Store ids whose configs are loaded.
    depth_max : str or float
        Passed to ``earthmodel_1d.extract`` to cut the source-site models.

    Returns
    -------
    list : [source_models, receiver_models]; receiver models are only
        collected for stores whose config defines ``earthmodel_receiver_1d``.
    """
    source_models, receiver_models = [], []
    for sid in store_ids:
        store_config = load(
            filename=os.path.join(store_superdir, sid, 'config'))
        source_models.append(
            store_config.earthmodel_1d.extract(depth_max=depth_max))

        receiver_model = store_config.earthmodel_receiver_1d
        if receiver_model is not None:
            receiver_models.append(receiver_model)

    return [source_models, receiver_models]
def draw_earthmodels(problem, plot_options):
    """
    Plot the velocity models used by the seismic and geodetic composites.

    For seismic data, one figure per station and site is created ('source'
    and, if the stores define one, 'receiver'). For geodetic data, one
    figure of the source-site models is created. Existing plots are skipped
    unless ``plot_options.force`` is set.

    Raises
    ------
    TypeError
        For datatypes other than 'seismic' and 'geodetic'.
    """
    po = plot_options

    for datatype, composite in problem.composites.items():

        if datatype == 'seismic':
            models_dict = {}
            sc = problem.config.seismic_config

            if sc.gf_config.reference_location is None:
                plot_stations = composite.datahandler.stations
            else:
                # work on a copy: renaming must not alter the station of
                # the data handler
                plot_stations = [
                    copy.deepcopy(composite.datahandler.stations[0])]
                plot_stations[0].station = \
                    sc.gf_config.reference_location.station

            for station in plot_stations:
                outbasepath = os.path.join(
                    problem.outfolder, po.figure_dir,
                    '%s_%s_velocity_model' % (
                        datatype, station.station))

                # the files written below carry a site suffix and the
                # extension, ``outbasepath`` itself never exists
                source_outpath = outbasepath + '_source.%s' % po.outformat
                if not os.path.exists(source_outpath) or po.force:
                    targets = init_seismic_targets(
                        [station],
                        earth_model_name=sc.gf_config.earth_model_name,
                        channels=sc.get_unique_channels()[0],
                        sample_rate=sc.gf_config.sample_rate,
                        crust_inds=list(range(*sc.gf_config.n_variations)),
                        interpolation='multilinear')
                    store_ids = [t.store_id for t in targets]

                    models = load_earthmodels(
                        composite.engine.store_superdirs[0], store_ids,
                        depth_max=sc.gf_config.depth_limit_variation * km)

                    for site, mods in zip(('source', 'receiver'), models):
                        outpath = outbasepath + \
                            '_%s.%s' % (site, po.outformat)
                        models_dict[outpath] = mods
                else:
                    logger.info(
                        '%s earthmodel plot for station %s exists. Use '
                        'force=True for replotting!' % (
                            datatype, station.station))

        elif datatype == 'geodetic':
            gc = problem.config.geodetic_config

            models_dict = {}
            outpath = os.path.join(
                problem.outfolder, po.figure_dir,
                '%s_%s_velocity_model.%s' % (
                    datatype, 'psgrn', po.outformat))

            if not os.path.exists(outpath) or po.force:
                targets = init_geodetic_targets(
                    datasets=composite.datasets,
                    earth_model_name=gc.gf_config.earth_model_name,
                    interpolation='multilinear',
                    crust_inds=list(range(*gc.gf_config.n_variations)),
                    sample_rate=gc.gf_config.sample_rate)

                # load_earthmodels expects store ids, it has no ``targets``
                # argument
                models = load_earthmodels(
                    store_superdir=composite.engine.store_superdirs[0],
                    store_ids=[t.store_id for t in targets],
                    depth_max=gc.gf_config.source_depth_max * km)
                models_dict[outpath] = models[0]  # select only source site
            else:
                logger.info(
                    '%s earthmodel plot exists. Use force=True for'
                    ' replotting!' % datatype)
                # proceed with the remaining datatypes instead of returning
                continue

        else:
            raise TypeError(
                'Plot for datatype %s not (yet) supported' % datatype)

        plotted = []
        for path, models in models_dict.items():
            if len(models) > 0:
                fig, _ = n_model_plot(
                    models, axes=None,
                    draw_bg=po.reference, highlightidx=[0])
                plotted.append((path, fig))

        if po.outformat == 'display':
            plt.show()
        else:
            for outpath, fig in plotted:
                logger.info('saving figure to %s' % outpath)
                fig.savefig(outpath, format=po.outformat, dpi=po.dpi)
                plt.close(fig)
def fuzzy_waveforms(
        ax, traces, linewidth, zorder=0, extent=None,
        grid_size=(500, 500), cmap=None, alpha=0.6):
    """
    Fuzzy waveforms

    Rasterise all traces as lines onto one grid and show the accumulated
    grid as an image; regions crossed by many traces get higher values.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to draw into.
    traces : list
        of class:`pyrocko.trace.Trace`, the times of the traces should not
        vary too much
    linewidth : int
        line width passed on to ``draw_line_on_array``
    zorder : int
        the higher number is drawn above the lower number
    extent : list
        of [xmin, xmax, ymin, ymax] (tmin, tmax, min/max of amplitudes)
        if None, it is determined from the traces that share the channel of
        the first trace, with an amplitude range symmetric around zero
    grid_size : tuple
        shape of the raster grid
    cmap : colormap, optional
        defaults to a 256 colour white-chocolate-red gradient
    alpha : float
        transparency of the image
    """
    if cmap is None:
        from matplotlib.colors import LinearSegmentedColormap
        cmap = LinearSegmentedColormap.from_list(
            'dummy', ['white', scolor('chocolate2'), scolor('scarletred2')],
            N=256)

    if extent is None:
        def channel_key(tr):
            return tr.channel

        first_channel = traces[0].channel
        amp_min, amp_max = trace.minmax(
            traces, key=channel_key)[first_channel]
        tmin, tmax = trace.minmaxtime(traces, key=channel_key)[first_channel]

        amp_lim = max(abs(amp_min), abs(amp_max))
        extent = [tmin, tmax, -amp_lim, amp_lim]

    grid = num.zeros(grid_size, dtype='float64')
    for tr in traces:
        draw_line_on_array(
            tr.get_xdata(), tr.ydata,
            grid=grid,
            extent=extent,
            grid_resolution=grid.shape,
            linewidth=linewidth)

    ax.imshow(
        grid, extent=extent, origin='lower', cmap=cmap, aspect='auto',
        alpha=alpha, zorder=zorder)
def fuzzy_rupture_fronts(
        ax, rupture_fronts, xgrid, ygrid, alpha=0.6, linewidth=7, zorder=0,
        res_km=25):
    """
    Fuzzy rupture fronts

    Rasterise the contour lines of all rupture fronts onto one grid and show
    it as a white-to-black image. Grid values above half the number of
    rupture fronts are truncated to increase the contrast.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to draw into.
    rupture_fronts : list
        of output of cs = pyplot.contour; cs.allsegs
    xgrid : array_like
        of center coordinates of the sub-patches of the fault in
        strike-direction in [km]
    ygrid : array_like
        of center coordinates of the sub-patches of the fault in
        dip-direction in [km]
    alpha : float
        transparency of the image
    linewidth : int
        line width passed on to ``draw_line_on_array``
    zorder : int
        drawing order of the image
    res_km : int
        raster resolution in pixel per km
    """
    from matplotlib.colors import LinearSegmentedColormap

    cmap = LinearSegmentedColormap.from_list(
        'dummy', ['white', 'black'], N=256)

    xmin = xgrid.min()
    xmax = xgrid.max()
    ymin = ygrid.min()
    ymax = ygrid.max()
    extent = (xmin, xmax, ymin, ymax)

    # the grid shape follows the coordinate span; abs(max) - abs(min) was
    # wrong (even negative) for negative coordinates
    nrows = max(int((ymax - ymin) * res_km), 1)
    ncols = max(int((xmax - xmin) * res_km), 1)
    grid = num.zeros((nrows, ncols), dtype='float64')

    for rupture_front in rupture_fronts:
        for level in rupture_front:
            for line in level:
                draw_line_on_array(
                    line[:, 0], line[:, 1],
                    grid=grid,
                    extent=extent,
                    grid_resolution=grid.shape,
                    linewidth=linewidth)

    # increase contrast reduce high intense values
    truncate = len(rupture_fronts) / 2
    grid[grid > truncate] = truncate
    ax.imshow(
        grid, extent=extent, origin='lower', cmap=cmap, aspect='auto',
        alpha=alpha, zorder=zorder)
def fault_slip_distribution(
        fault, mtrace=None, transform=lambda x: x, alpha=0.9, ntickmarks=5,
        reference=None, nensemble=1):
    """
    Draw discretized fault geometry rotated to the 2-d view of the foot-wall
    of the fault.

    One figure is created per subfault. The patches are coloured by the
    total slip of ``reference``. Slip vectors of ``reference`` are drawn as
    black arrows if the mean of its 'uparr' values is non-zero. If
    ``mtrace`` is given, 2-sigma ellipses of the sampled slip components are
    added. For faults with the 'seismic' datatype, the rupture front
    contours and the hypocentre of ``reference`` are drawn as well. With
    ``mtrace``, fuzzy rupture fronts of roughly ``nensemble`` samples are
    rendered in addition.

    Parameters
    ----------
    fault : :class:`ffi.fault.FaultGeometry`
    mtrace : trace object, optional
        Sampling result providing ``get_values``; see above.
    transform : callable
        Applied to the values read from ``mtrace``.
    alpha : float
        Patch transparency; currently only used by the disabled
        (``if False``) durations plot.
    ntickmarks : int
        Number of bins of the tick locators of the patch axes.
    reference : dict
        Point with (at least) 'uperp' and 'uparr' entries and, for
        'seismic' faults, 'nucleation_dip' and 'nucleation_strike'.
    nensemble : int
        Number of samples used for the fuzzy rupture fronts.

    Returns
    -------
    figs : list of :class:`matplotlib.figure.Figure`
    axs : list of :class:`matplotlib.axes.Axes`
    """

    def draw_quivers(
            ax, uperp, uparr, xgr, ygr, rake, color='black',
            draw_legend=False, normalisation=None, zorder=0):
        # Draw slip vectors at the patch centres, scaled by
        # ``normalisation`` (maximum slip if None); returns the quiver
        # object and the normalisation used.

        # positive uperp is always dip-normal- have to multiply -1
        angles = num.arctan2(-uperp, uparr) * mt.r2d + rake
        slips = num.sqrt((uperp ** 2 + uparr ** 2)).ravel()

        if normalisation is None:
            # NOTE(review): ``distances`` and ``centers`` are unused here
            from beat.models.laplacian import distances
            centers = num.vstack((xgr, ygr)).T
            #interpatch_dists = distances(centers, centers)
            normalisation = slips.max()

        slips /= normalisation

        slipsx = num.cos(angles * mt.d2r) * slips
        slipsy = num.sin(angles * mt.d2r) * slips

        # slip arrows of slip on patches
        quivers = ax.quiver(
            xgr.ravel(), ygr.ravel(), slipsx, slipsy,
            units='dots', angles='xy', scale_units='xy', scale=1,
            width=1., color=color, zorder=zorder)

        if draw_legend:
            # NOTE(review): only used by the disabled quiverkey below
            quiver_legend_length = num.ceil(
                num.max(slips * normalisation) * 10.) / 10.

            # ax.quiverkey(
            #    quivers, 0.9, 0.8, quiver_legend_length,
            #    '{} [m]'.format(quiver_legend_length), labelpos='E',
            #    coordinates='figure')

        return quivers, normalisation

    def draw_patches(
            ax, fault, subfault_idx, patch_values, cmap, alpha, cbounds=None,
            xlim=None):
        # Draw the patches of one subfault as rectangles coloured by
        # ``patch_values`` and set up the axes; returns the collection.

        lls = fault.get_subfault_patch_attributes(
            subfault_idx, attributes=['bottom_left'])
        widths, lengths = fault.get_subfault_patch_attributes(
            subfault_idx, attributes=['width', 'length'])
        sf = fault.get_subfault(subfault_idx)

        # subtract reference fault lower left and rotate
        rot_lls = utility.rotate_coords_plane_normal(lls, sf)[:, 1::-1]

        d_patches = []
        for ll, width, length in zip(rot_lls, widths, lengths):
            d_patches.append(
                Rectangle(
                    ll, width=length, height=width, edgecolor='black'))

        lower = rot_lls.min(axis=0)
        pad = sf.length / km * 0.05

        #xlim = [lower[0] - pad, lower[0] + sf.length / km + pad]
        if xlim is None:
            xlim = [lower[1] - pad, lower[1] + sf.width / km + pad]

        ax.set_aspect(1)
        #ax.set_xlim(*xlim)
        ax.set_xlim(*xlim)

        # shift the y-tick labels by the subfault width
        scale_y = {'scale': 1, 'offset': (-sf.width / km)}
        scale_axes(ax.yaxis, **scale_y)

        # ``fontsize`` is resolved from the enclosing function at call time
        ax.set_xlabel(
            'strike-direction [km]', fontsize=fontsize)
        ax.set_ylabel(
            'dip-direction [km]', fontsize=fontsize)

        xticker = tick.MaxNLocator(nbins=ntickmarks)
        yticker = tick.MaxNLocator(nbins=ntickmarks)

        ax.get_xaxis().set_major_locator(xticker)
        ax.get_yaxis().set_major_locator(yticker)

        pa_col = PatchCollection(
            d_patches, alpha=alpha, match_original=True, zorder=0)
        pa_col.set(array=patch_values, cmap=cmap)

        if cbounds is not None:
            pa_col.set_clim(*cbounds)

        ax.add_collection(pa_col)
        return pa_col

    def draw_colorbar(fig, ax, cb_related, labeltext, ntickmarks=4):
        # Add a colorbar in a fixed axes on the right side of ``fig``.
        cbaxes = fig.add_axes([0.88, 0.4, 0.03, 0.3])
        # NOTE(review): ``axs`` is the list of axes collected so far by the
        # enclosing function, not ``ax``; presumably ignored by matplotlib
        # because ``cax`` is given - confirm.
        cb = fig.colorbar(cb_related, ax=axs, cax=cbaxes)
        cb.set_label(labeltext, fontsize=fontsize)
        cb.locator = tick.MaxNLocator(nbins=ntickmarks)
        cb.update_ticks()
        ax.set_aspect('equal', adjustable='box')

    def get_values_from_trace(mtrace, fault, varname, reference):
        # Samples of ``varname`` from the trace; falls back to the
        # (2-d) values of ``reference`` if the trace lacks the variable.
        try:
            u = transform(
                mtrace.get_values(
                    varname, combine=True, squeeze=True))
        except(ValueError, KeyError):
            u = num.atleast_2d(fault.var_from_point(
                index=None, point=reference, varname=varname))
        return u

    from beat.colormap import slip_colormap
    fontsize = 12

    reference_slip = fault.get_total_slip(index=None, point=reference)
    slip_bounds = [0, reference_slip.max()]

    figs = []
    axs = []

    # common x-range for all subfaults: longest subfault plus padding
    flengths_max = num.array(
        [sf.length / km for sf in fault.iter_subfaults()]).max()
    pad = flengths_max * 0.03
    xmax = flengths_max + pad
    for ns in range(fault.nsubfaults):
        fig, ax = plt.subplots(
            nrows=1, ncols=1, figsize=mpl_papersize('a5', 'landscape'))

        # alphas = alpha * num.ones(np_h * np_w, dtype='int8')

        try:
            ext_source = fault.get_subfault(ns, component='uparr')
        except TypeError:
            ext_source = fault.get_subfault(ns, component='utens')

        patch_idxs = fault.get_patch_indexes(ns)

        pa_col = draw_patches(
            ax, fault,
            subfault_idx=ns,
            patch_values=reference_slip[patch_idxs], xlim=[-pad, xmax],
            cmap=slip_colormap(100), alpha=0.65, cbounds=slip_bounds)

        # patch central locations
        centers = fault.get_subfault_patch_attributes(
            ns, attributes=['center'])
        rot_centers = utility.rotate_coords_plane_normal(
            centers, ext_source)[:, 1::-1]

        xgr, ygr = rot_centers.T
        if 'seismic' in fault.datatypes:
            # 2-d grids are needed for the contour plots below
            shp = fault.ordering.get_subfault_discretization(ns)
            xgr = xgr.reshape(shp)
            ygr = ygr.reshape(shp)

            if mtrace is not None:
                from tqdm import tqdm
                nuc_dip = transform(mtrace.get_values(
                    'nucleation_dip', combine=True, squeeze=True))
                nuc_strike = transform(mtrace.get_values(
                    'nucleation_strike', combine=True, squeeze=True))
                velocities = transform(mtrace.get_values(
                    'velocities', combine=True, squeeze=True))

                nchains = len(mtrace)
                # NOTE(review): overwritten below before being used
                csteps = 6
                rupture_fronts = []
                # off-screen axes only used to compute the contour lines;
                # NOTE(review): dummy_fig is never closed
                dummy_fig, dummy_ax = plt.subplots(
                    nrows=1, ncols=1, figsize=mpl_papersize('a5', 'landscape'))

                # evenly spaced sample indexes, roughly nensemble of them
                csteps = float(nchains) / nensemble
                idxs = num.floor(
                    num.arange(0, nchains, csteps)).astype('int32')
                logger.info('Rendering rupture fronts ...')
                for i in tqdm(idxs):
                    nuc_dip_idx, nuc_strike_idx = fault.fault_locations2idxs(
                        ns, nuc_dip[i], nuc_strike[i], backend='numpy')
                    veloc_ns = fault.vector2subfault(
                        index=ns, vector=velocities[i, :])
                    sts = fault.get_subfault_starttimes(
                        ns, veloc_ns, nuc_dip_idx[ns], nuc_strike_idx[ns])

                    contours = dummy_ax.contour(xgr, ygr, sts)
                    rupture_fronts.append(contours.allsegs)

                fuzzy_rupture_fronts(
                    ax, rupture_fronts, xgr, ygr,
                    alpha=1., linewidth=7, zorder=-1)

                # NOTE(review): std_durations is computed but unused
                durations = transform(mtrace.get_values(
                    'durations', combine=True, squeeze=True))
                std_durations = durations.std(axis=0)
                # alphas = std_durations.min() / std_durations

            # rupture durations (plot currently disabled)
            if False:
                fig2, ax2 = plt.subplots(
                    nrows=1, ncols=1,
                    figsize=mpl_papersize('a5', 'landscape'))

                reference_durations = reference['durations'][patch_idxs]

                pa_col2 = draw_patches(
                    ax2, fault, subfault_idx=ns,
                    patch_values=reference_durations,
                    cmap=plt.cm.seismic, alpha=alpha, xlim=[-pad, xmax])

                draw_colorbar(fig2, ax2, pa_col2, labeltext='durations [s]')
                figs.append(fig2)
                axs.append(ax2)

            # rupture front contours of the reference point
            ref_starttimes = fault.point2starttimes(reference, index=ns)
            contours = ax.contour(
                xgr, ygr, ref_starttimes,
                colors='black', linewidths=0.5, alpha=0.9)

            # draw subfault hypocenter
            dip_idx, strike_idx = fault.fault_locations2idxs(
                ns,
                reference['nucleation_dip'][ns],
                reference['nucleation_strike'][ns],
                backend='numpy')

            # centre of the nucleation patch; y is measured from the
            # subfault width downwards
            psize_strike = fault.ordering.patch_sizes_strike[ns]
            psize_dip = fault.ordering.patch_sizes_dip[ns]
            nuc_strike = strike_idx * psize_strike + (psize_strike / 2.)
            nuc_dip = dip_idx * psize_dip + (psize_dip / 2.)
            ax.plot(
                nuc_strike, ext_source.width / km - nuc_dip,
                marker='*', color='k', markersize=12)

            # label contourlines
            plt.clabel(contours, inline=True, fontsize=10,
                       fmt=tick.FormatStrFormatter('%.1f'))

        if mtrace is not None:
            logger.info('Drawing quantiles ...')

            uparr = get_values_from_trace(
                mtrace, fault, 'uparr', reference)[:, patch_idxs]
            uperp = get_values_from_trace(
                mtrace, fault, 'uperp', reference)[:, patch_idxs]
            utens = get_values_from_trace(
                mtrace, fault, 'utens', reference)[:, patch_idxs]

            uparrmean = uparr.mean(axis=0)
            uperpmean = uperp.mean(axis=0)
            utensmean = utens.mean(axis=0)

            # NOTE(review): if neither branch applies (all mean components
            # sum to zero), uperpstd/uparrstd/quivers stay undefined and the
            # ellipse loop below raises NameError.
            if uparrmean.sum() != 0.:
                logger.info('Found slip shear components!')
                normalisation = slip_bounds[1] / 3
                quivers, normalisation = draw_quivers(
                    ax, uperpmean, uparrmean, xgr, ygr,
                    ext_source.rake, color='grey',
                    draw_legend=False, normalisation=normalisation)
                uparrstd = uparr.std(axis=0) / normalisation
                uperpstd = uperp.std(axis=0) / normalisation
            elif utensmean.sum() != 0:
                logger.info(
                    'Found tensile slip components! Not drawing quivers!'
                    ' Circle radius shows standard deviations!')
                uperpstd = uparrstd = utens.std(axis=0)
                normalisation = utens.max()
                quivers = None

            # rotation of the ellipses into the rake direction
            slipvecrotmat = mt.euler_to_matrix(
                0.0, 0.0, ext_source.rake * mt.d2r)

            circle = num.linspace(0, 2 * num.pi, 100)
            # 2sigma error ellipses
            for i, (upe, upa) in enumerate(zip(uperpstd, uparrstd)):
                ellipse_x = 2 * upa * num.cos(circle)
                ellipse_y = 2 * upe * num.sin(circle)
                ellipse = num.vstack(
                    [ellipse_x, ellipse_y, num.zeros_like(ellipse_x)]).T
                rot_ellipse = ellipse.dot(slipvecrotmat)

                xcoords = xgr.ravel()[i] + rot_ellipse[:, 0]
                ycoords = ygr.ravel()[i] + rot_ellipse[:, 1]
                # centre the ellipses on the arrow tips of the mean slip
                if quivers is not None:
                    xcoords += quivers.U[i]
                    ycoords += quivers.V[i]
                ax.plot(xcoords, ycoords, '-k', linewidth=0.5, zorder=2)
        else:
            normalisation = None

        # reference slip vectors are drawn in either case
        uperp = reference['uperp'][patch_idxs]
        uparr = reference['uparr'][patch_idxs]

        if uparr.mean() != 0.:
            logger.info('Drawing slip vectors ...')
            draw_quivers(
                ax, uperp, uparr,
                xgr, ygr, ext_source.rake, color='black', draw_legend=True,
                normalisation=normalisation, zorder=3)

        draw_colorbar(fig, ax, pa_col, labeltext='slip [m]')
        format_axes(ax, remove=['top', 'right'])

        # fig.tight_layout()
        figs.append(fig)
        axs.append(ax)

    return figs, axs
class ModeError(Exception):
    """Raised when a plot is requested for an unsuitable optimization mode."""
def draw_slip_dist(problem, po):
    """
    Plot the finite-fault slip distribution and display or store it.

    Parameters
    ----------
    problem : beat problem object
        has to be set up in the finite-fault ("ffi") optimization mode
    po : :class:`PlotOptions`
        plot options; if ``po.reference`` is given it is plotted instead
        of the result point extracted from the sampled trace

    Raises
    ------
    ModeError
        if the problem is not in finite-fault mode
    """
    mode = problem.config.problem_config.mode
    if mode != ffi_mode_str:
        raise ModeError(
            'Wrong optimization mode: %s! This plot '
            'variant is only valid for "%s" mode' % (mode, ffi_mode_str))

    _, composite = list(problem.composites.items())[0]
    fault = composite.load_fault_geometry()

    if po.reference:
        reference = po.reference
        llk_str = 'ref'
        mtrace = None
        stage_number = -1
    else:
        stage = load_stage(
            problem, stage_number=po.load_stage, load='trace', chains=[-1])
        reference = problem.config.problem_config.get_test_point()
        reference.update(get_result_point(stage.mtrace, po.post_llk))
        llk_str = po.post_llk
        mtrace = stage.mtrace
        stage_number = stage.number

    figs, _ = fault_slip_distribution(
        fault, mtrace, reference=reference, nensemble=po.nensemble)

    if po.outformat == 'display':
        plt.show()
        return

    outpath = os.path.join(
        problem.outfolder, po.figure_dir,
        'slip_dist_%i_%s_%i' % (stage_number, llk_str, po.nensemble))
    logger.info('Storing slip-distribution to: %s' % outpath)

    if po.outformat == 'pdf':
        # all figures go into one multi-page document
        with PdfPages(outpath + '.pdf') as opdf:
            for fig in figs:
                opdf.savefig(fig, dpi=po.dpi)
    else:
        # one numbered file per figure
        for fig_number, fig in enumerate(figs):
            fig.savefig(
                outpath + '_%i.%s' % (fig_number, po.outformat), dpi=po.dpi)
def _weighted_line(
r0, c0, r1, c1, w, rmin=0, rmax=num.inf, cmin=0, cmax=num.inf):
"""
Draw weighted lines into array
Modiefied from:
https://stackoverflow.com/questions/31638651/how-can-i-draw-lines-into-numpy-arrays
Parameters
----------
r0 : int
row index for line end point 0
c0 : int
col index for line end point 0
r1 : int
row index for line end point 1
c1 : int
col index for line end point 1
w : int
width in pixels for line
rmin : int
min row index for grid to draw on
rmax : int
max row index for grid to draw on
Returns
-------
rr : array of row indexes of line
cc : array of col indexes of line
w : array of line weights
"""
def trapez(y, y0, w):
return num.clip(num.minimum(
y + 1 + w / 2 - y0,
- y + 1 + w / 2 + y0), 0, 1)
# The algorithm below works fine if c1 >= c0 and c1-c0 >= abs(r1-r0).
# If either of these cases are violated, do some switches.
if abs(c1 - c0) < abs(r1 - r0):
# Switch x and y, and switch again when returning.
xx, yy, val = _weighted_line(
c0, r0, c1, r1, w=w, rmin=cmin, rmax=cmax, cmin=rmin, cmax=rmax)
return (yy, xx, val)
# At this point we know that the distance in columns (x) is greater
# than that in rows (y). Possibly one more switch if c0 > c1.
if c0 > c1:
return _weighted_line(
r1, c1, r0, c0, w=w, rmin=rmin, rmax=rmax, cmin=cmin, cmax=cmax)
# The following is now always < 1 in abs
slope = (r1 - r0) / (c1 - c0)
# Adjust weight by the slope
w *= num.sqrt(1 + num.abs(slope)) / 2
# We write y as a function of x, because the slope is always <= 1
# (in absolute value)
x = num.arange(c0, c1 + 1, dtype=float)
y = (x * slope) + ((c1 * r0) - (c0 * r1)) / (c1 - c0)
# Now instead of 2 values for y, we have 2*np.ceil(w/2).
# All values are 1 except the upmost and bottommost.
thickness = num.ceil(w / 2)
yy = (num.floor(y).reshape(-1, 1) +
num.arange(-thickness - 1, thickness + 2).reshape(1, -1))
xx = num.repeat(x, yy.shape[1])
vals = trapez(yy, y.reshape(-1, 1), w).flatten()
yy = yy.flatten()
# Exclude useless parts and those outside of the interval
# to avoid parts outside of the picture
mask_y = num.logical_and.reduce((yy >= rmin, yy < rmax, vals > 0))
mask_x = num.logical_and.reduce((xx >= cmin, xx < cmax, vals > 0))
mask = num.logical_and.reduce((mask_y > 0, mask_x > 0))
return (yy[mask].astype(int), xx[mask].astype(int), vals[mask])
def draw_line_on_array(
        X, Y, grid=None, extent=None, grid_resolution=(400, 400),
        linewidth=1):
    """
    Draw line on given array by adding its weights to the array fields.

    Parameters
    ----------
    X : array_like
        timeseries on xcoordinate (columns of array)
    Y : array_like
        timeseries on ycoordinate (rows of array)
    grid : array_like 2d, optional
        input array that is used for drawing (added to in place); a new
        float64 array of shape ``grid_resolution`` is allocated if None
    extent : sequence of length 4, optional
        [xmin, xmax, ymin, ymax] (cols, rows); derived from X and Y if
        None or empty
    grid_resolution : tuple
        (nrows, ncols) shape of given grid or grid that is being used for
        allocation
    linewidth : int
        weight (width) of line drawn on grid

    Returns
    -------
    grid, extent

    Raises
    ------
    TypeError
        if X and Y differ in length, extent/grid_resolution/grid have
        invalid shapes, or the line leaves the grid
    """
    def check_grid_shape(ngr, naim, axis):
        if ngr != naim:
            raise TypeError(
                'Gridsize of given grid is inconsistent for axis %i!'
                ' Expected %i got %i' % (axis, naim, ngr))

    def check_line_in_grid(idxs, axis, nmax, extent):
        imax = idxs.max()
        if imax > nmax:
            raise TypeError(
                'Line endpoint outside of given grid Axis "%s"! %i > %i'
                ' Extent [%s]' % (
                    axis, imax, nmax, utility.list2string(extent)))

    nxs = len(X)
    nys = len(Y)
    if nxs != nys:
        raise TypeError(
            'Length of X and Y have to be identical! %i != %i' % (nxs, nys))

    if extent is None or len(extent) == 0:
        xmin = X.min()
        xmax = X.max()
        ymin = Y.min()
        ymax = Y.max()
        extent = [xmin, xmax, ymin, ymax]
    elif len(extent) == 4:
        xmin, xmax, ymin, ymax = extent
    else:
        raise TypeError(
            'extent has to be of length 4! [xmin, xmax, ymin, ymax]')

    if len(grid_resolution) != 2:
        raise TypeError(
            'grid_resolution has to be of length 2! [nrows, ncols]!')

    ynstep, xnstep = grid_resolution

    _, xstep = num.linspace(xmin, xmax, xnstep, endpoint=True, retstep=True)
    _, ystep = num.linspace(ymin, ymax, ynstep, endpoint=True, retstep=True)

    if grid is not None:
        if grid.ndim != 2:
            raise TypeError('Given grid has to be of dimension 2!')

        for axis, (ngr, naim) in enumerate(
                zip(grid.shape, grid_resolution)):
            check_grid_shape(ngr, naim, axis)
    else:
        grid = num.zeros((ynstep, xnstep), dtype='float64')

    # map the data positions to grid cell indexes
    xidxs = utility.positions2idxs(
        X, min_pos=xmin, cell_size=xstep, dtype='int32')
    yidxs = utility.positions2idxs(
        Y, min_pos=ymin, cell_size=ystep, dtype='int32')

    check_line_in_grid(xidxs, 'x', nmax=xnstep - 1, extent=extent)
    check_line_in_grid(yidxs, 'y', nmax=ynstep - 1, extent=extent)

    new_grid = num.zeros_like(grid)
    for i in range(1, nxs):
        r0, c0 = yidxs[i - 1], xidxs[i - 1]
        r1, c1 = yidxs[i], xidxs[i]

        if r0 == r1 and c0 == c1:
            # line start and end fall in the same grid point, cant be drawn
            continue

        # rmax/cmax are exclusive bounds in _weighted_line, so the grid
        # shape (not the last index) keeps the last row/col drawable
        rr, cc, w = _weighted_line(
            r0=r0, c0=c0, r1=r1, c1=c1, w=linewidth,
            rmax=ynstep, cmax=xnstep)
        # assignment (not addition): pixels shared by consecutive segments
        # of this line are counted only once
        new_grid[rr, cc] = w.astype(grid.dtype)

    grid += new_grid
    return grid, extent
def fuzzy_moment_rate(
        ax, moment_rates, times, cmap=None, grid_size=(500, 500)):
    """
    Plot fuzzy moment rate function into axes.

    Every moment rate function of the ensemble is rasterized onto a common
    grid, so that regions crossed by many functions get high values.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        axes to draw into
    moment_rates : list of array_like
        moment rates of each ensemble member
    times : list of array_like
        times corresponding to the entries of ``moment_rates``
    cmap : matplotlib colormap, optional
        defaults to ``plt.cm.hot_r``
    grid_size : tuple
        (nrows, ncols) of the rasterization grid

    Raises
    ------
    TypeError
        if the number of rate and time arrays differ
    """
    nrates = len(moment_rates)
    ntimes = len(times)
    if nrates != ntimes:
        raise TypeError(
            'Number of rates and times have to be identical!'
            ' %i != %i' % (nrates, ntimes))

    if cmap is None:
        cmap = plt.cm.hot_r

    extent = (
        min(num.min(tvec) for tvec in times),
        max(num.max(tvec) for tvec in times),
        min(num.min(rates) for rates in moment_rates),
        max(num.max(rates) for rates in moment_rates))

    grid = num.zeros(grid_size, dtype='float64')
    for rates, tvec in zip(moment_rates, times):
        draw_line_on_array(
            tvec, rates,
            grid=grid,
            extent=extent,
            grid_resolution=grid.shape,
            linewidth=7)

    # increase contrast: cap highly populated pixels at half the ensemble
    num.minimum(grid, nrates / 2, out=grid)

    ax.imshow(grid, extent=extent, origin='lower', cmap=cmap, aspect='auto')

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(tick.MaxNLocator(nbins=5))

    ax.set_xlabel('Time [s]')
    ax.set_ylabel('Moment rate [$Nm / s$]')
def draw_moment_rate(problem, po):
    """
    Draw moment rate function for the results of a seismic/joint finite fault
    optimization.

    Depending on ``po.plot_projection`` one figure per subfault
    ('individual') or one figure of the total moment rate is drawn. If a
    trace is used (no ``po.reference``), an ensemble of ``po.nensemble``
    moment rate functions is drawn as a density image beneath the reference
    moment rate function.

    Raises
    ------
    ModeError
        if the problem is not in finite-fault mode
    TypeError
        if the optimization did not include seismic data
    """
    fontsize = 12
    mode = problem.config.problem_config.mode

    if mode != ffi_mode_str:
        raise ModeError(
            'Wrong optimization mode: %s! This plot '
            'variant is only valid for "%s" mode' % (mode, ffi_mode_str))

    if 'seismic' not in problem.config.problem_config.datatypes:
        raise TypeError(
            'Moment rate function only available for optimization results that'
            ' include seismic data.')

    sc = problem.composites['seismic']
    fault = sc.load_fault_geometry()

    stage = load_stage(
        problem, stage_number=po.load_stage, load='trace', chains=[-1])

    if not po.reference:
        reference = get_result_point(stage.mtrace, po.post_llk)
        llk_str = po.post_llk
        mtrace = stage.mtrace
    else:
        reference = po.reference
        llk_str = 'ref'
        mtrace = None

    logger.info(
        'Drawing ensemble of %i moment rate functions ...' % po.nensemble)
    target = sc.wavemaps[0].targets[0]
    # the store only depends on target.store_id - look it up once instead
    # of once per evaluated moment rate function
    store = sc.engine.get_store(target.store_id)

    if po.plot_projection == 'individual':
        logger.info('Drawing subfault individual rates ...')
        sf_idxs = range(fault.nsubfaults)
    else:
        logger.info('Drawing total rates ...')
        sf_idxs = [list(range(fault.nsubfaults))]

    mpl_init(fontsize=fontsize)
    for i, ns in enumerate(sf_idxs):
        logger.info('Fault %i / %i' % (i + 1, len(sf_idxs)))
        if isinstance(ns, list):
            ns_str = 'total'
        else:
            ns_str = str(ns)

        outpath = os.path.join(
            problem.outfolder, po.figure_dir,
            'moment_rate_%i_%s_%s_%i.%s' % (
                stage.number, ns_str,
                llk_str, po.nensemble, po.outformat))

        if os.path.exists(outpath) and not po.force:
            logger.info('Plot exists! Use --force to overwrite!')
            continue

        # only computed when the figure is actually drawn
        ref_mrf_rates, ref_mrf_times = fault.get_moment_rate_function(
            index=ns, point=reference, target=target, store=store)

        fig, ax = plt.subplots(
            nrows=1, ncols=1, figsize=mpl_papersize('a6', 'landscape'))

        labelpos = mpl_margins(
            fig, left=5, bottom=4, top=1.5, right=0.5, units=fontsize)
        labelpos(ax, 2., 1.5)

        if mtrace is not None:
            nchains = len(mtrace)
            csteps = float(nchains) / po.nensemble
            idxs = num.floor(
                num.arange(0, nchains, csteps)).astype('int32')
            mrfs_rate = []
            mrfs_time = []
            for idx in idxs:
                point = mtrace.point(idx=idx)
                mrf_rate, mrf_time = fault.get_moment_rate_function(
                    index=ns, point=point, target=target, store=store)
                mrfs_rate.append(mrf_rate)
                mrfs_time.append(mrf_time)

            fuzzy_moment_rate(ax, mrfs_rate, mrfs_time)

        ax.plot(
            ref_mrf_times, ref_mrf_rates,
            '-k', alpha=0.8, linewidth=1.)

        format_axes(ax, remove=['top', 'right'])

        if po.outformat == 'display':
            plt.show()
        else:
            logger.info('saving figure to %s' % outpath)
            fig.savefig(outpath, format=po.outformat, dpi=po.dpi)

        # release the figure before drawing the next subfault
        plt.close(fig)
def source_geometry(fault, ref_sources, event, datasets=None, values=None,
                    cmap=None, title=None, show=True, cbounds=None, clabel=''):
    """
    Plot source geometry in 3d rotatable view

    Parameters
    ----------
    fault: :class:`beat.ffi.fault.FaultGeometry`
    ref_sources: list
        of :class:'beat.sources.RectangularSource'
    event : :class:`pyrocko.model.event.Event`
        reference location; all coordinates are shifted relative to it
    datasets : list, optional
        datasets whose ``east_shifts``, ``north_shifts`` and
        ``coords5[:, 4]`` are scattered into the plot
    values : array_like, optional
        one value per patch used to color the patches
    cmap : matplotlib colormap, optional
        colormap for ``values``; defaults to 'RdYlBu_r'
    title : str, optional
        axes title
    show : bool
        if True call ``plt.show()``
    cbounds : tuple, optional
        (min, max) color limits; defaults to the range of ``values``
    clabel : str
        colorbar label

    Returns
    -------
    fig, ax
    """
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 registers '3d'
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection

    alpha = 0.7

    def plot_subfault(ax, source, color, refloc):
        source.anchor = 'top'
        shift_ne = otd.latlon_to_ne(
            refloc.lat, refloc.lon, source.lat, source.lon)
        coords = source.outline()  # (N, E, Z)
        coords[:, 0:2] += shift_ne
        ax.plot(
            coords[:, 1], coords[:, 0], coords[:, 2] * -1.,
            color=color, linewidth=2, alpha=alpha)
        # highlight the first outline edge in black
        ax.plot(
            coords[0:2, 1], coords[0:2, 0], coords[0:2, 2] * -1.,
            '-k', linewidth=2, alpha=alpha)
        center = source.center  # (E, N, Z)
        center[0] += shift_ne[1]
        center[1] += shift_ne[0]
        ax.scatter(
            center[0], center[1], center[2] * -1,
            marker='o', s=20, color=color, alpha=alpha)

    def set_axes_radius(ax, origin, radius, axes='xyz'):
        # axes is a string of axis letters, e.g. 'xy'
        if 'x' in axes:
            ax.set_xlim3d([origin[0] - radius, origin[0] + radius])
        if 'y' in axes:
            ax.set_ylim3d([origin[1] - radius, origin[1] + radius])
        if 'z' in axes:
            ax.set_zlim3d([origin[2] - radius, origin[2] + radius])

    def set_axes_equal(ax, axes='xyz'):
        '''
        Make axes of 3D plot have equal scale so that spheres appear as
        spheres, cubes as cubes, etc..
        This is one possible solution to Matplotlib's
        ax.set_aspect('equal') and ax.axis('equal') not working for 3D.
        Input
          ax: a matplotlib axis, e.g., as output from plt.gca().
        '''
        limits = num.array([
            ax.get_xlim3d(),
            ax.get_ylim3d(),
            ax.get_zlim3d(),
        ])

        origin = num.mean(limits, axis=1)
        radius = 0.5 * num.max(num.abs(limits[:, 1] - limits[:, 0]))
        set_axes_radius(ax, origin, radius, axes=axes)

    fig = plt.figure(figsize=mpl_papersize('a5', 'landscape'))
    ax = fig.add_subplot(111, projection='3d')
    extfs = fault.get_all_subfaults()

    arr_coords = []
    for idx, (refs, exts) in enumerate(zip(ref_sources, extfs)):
        plot_subfault(ax, exts, color=mpl_graph_color(idx), refloc=event)
        plot_subfault(ax, refs, color=scolor('aluminium4'), refloc=event)
        for i, patch in enumerate(fault.get_subfault_patches(idx)):
            coords = patch.outline()
            shift_ne = otd.latlon_to_ne(
                event.lat, event.lon, patch.lat, patch.lon)
            coords[:, 0:2] += shift_ne
            coords[:, 2] *= -1.
            # swap columns to [E, N, Z] (X, Y, Z)
            coords[:, [0, 1]] = coords[:, [1, 0]]
            arr_coords.append(coords)
            ax.plot(
                coords[:, 0], coords[:, 1], coords[:, 2], zorder=2,
                color=mpl_graph_color(idx), linewidth=0.5, alpha=alpha)
            # global patch number
            ax.text(
                patch.east_shift + shift_ne[1],
                patch.north_shift + shift_ne[0], patch.center[2] * -1.,
                str(i + fault.cum_subfault_npatches[idx]), zorder=3,
                fontsize=8)

    if values is not None:
        if cmap is None:
            # plt.get_cmap works in old and new matplotlib, unlike
            # matplotlib.cm.get_cmap which was removed in 3.9
            cmap = plt.get_cmap('RdYlBu_r')

        poly_patches = Poly3DCollection(
            verts=arr_coords, zorder=1, cmap=cmap)
        poly_patches.set_array(values)

        if cbounds is None:
            poly_patches.set_clim(values.min(), values.max())
        else:
            poly_patches.set_clim(*cbounds)

        poly_patches.set_alpha(0.6)
        poly_patches.set_edgecolor('k')
        ax.add_collection(poly_patches)
        cbs = plt.colorbar(
            poly_patches,
            ax=ax,
            orientation='vertical',
            cmap=cmap)

        if clabel is not None:
            cbs.set_label(clabel)

    if datasets:
        for dataset in datasets:
            ax.scatter(
                dataset.east_shifts,
                dataset.north_shifts,
                dataset.coords5[:, 4],
                s=10, alpha=0.6, marker='o', color='black')

    scale = {'scale': 1. / km}
    scale_axes(ax.xaxis, **scale)
    scale_axes(ax.yaxis, **scale)
    scale_axes(ax.zaxis, **scale)
    ax.set_zlabel('Depth [km]')
    ax.set_ylabel('North_shift [km]')
    ax.set_xlabel('East_shift [km]')

    set_axes_equal(ax, axes='xy')

    # view along the mean strike; the elevation angle is left at the
    # current (matplotlib default) value
    strikes = num.array([extf.strike for extf in extfs])
    azim = strikes.mean() - 270
    logger.debug(
        'Viewing azimuth %s and elevation angles %s', azim, ax.elev)
    ax.view_init(ax.elev, azim)

    if title is not None:
        ax.set_title(title)

    if show:
        plt.show()

    return fig, ax
def get_gmt_config(gmtpy, fontsize=14, h=20., w=20.):
    """
    Return a dictionary of GMT defaults matching the installed GMT version.

    Parameters
    ----------
    gmtpy : module
        :mod:`pyrocko.gmtpy` (anything providing ``is_gmt5`` and ``cm``)
    fontsize : int
        font size [pt] used for annotations/labels
    h : float
        paper height [cm]
    w : float
        paper width [cm]

    Returns
    -------
    dict
        GMT configuration parameters
    """
    use_gmt5_names = gmtpy.is_gmt5(version='newest')

    label_font = '%ip,Helvetica,black' % fontsize
    media = 'Custom_%ix%i' % (w * gmtpy.cm, h * gmtpy.cm)

    if not use_gmt5_names:
        # parameter names of pre-GMT5 versions
        return {
            'MAP_FRAME_TYPE': 'fancy',
            'GRID_PEN_PRIMARY': '0.01p',
            'ANNOT_FONT_PRIMARY': '1',
            'ANNOT_FONT_SIZE_PRIMARY': '12p',
            'PLOT_DEGREE_FORMAT': 'D',
            'GRID_PEN_SECONDARY': '0.01p',
            'FONT_LABEL': label_font,
            'PS_MEDIA': media,
        }

    return {
        'MAP_GRID_PEN_PRIMARY': '0.1p',
        'MAP_GRID_PEN_SECONDARY': '0.1p',
        'MAP_FRAME_TYPE': 'fancy',
        'FONT_ANNOT_PRIMARY': label_font,
        'FONT_ANNOT_SECONDARY': label_font,
        'FONT_LABEL': label_font,
        'FORMAT_GEO_MAP': 'D',
        'GMT_TRIANGULATE': 'Watson',
        'PS_MEDIA': media,
    }
def draw_data_stations(
        gmt, stations, data, dist, data_cpt=None,
        scale_label=None, *args):
    """
    Draw data values (e.g. MAP time-shifts) at station locations as colored
    symbols and optionally add a colorbar.

    Parameters
    ----------
    gmt : :class:`pyrocko.plot.gmtpy.GMT`
    stations : list
        of :class:`pyrocko.model.station.Station`
    data : :class:`numpy.ndarray`
        one value per station
    dist : float
        maximum distance [deg]; selects the colorbar position
    data_cpt : str, optional
        path to a GMT '*.cpt' file; a symmetric blue-white-red palette is
        created if None
    scale_label : str, optional
        colorbar label; no colorbar is drawn if empty or None
    *args : str
        extra GMT arguments for ``psxy`` (projection, region, symbol ...)
    """
    # symmetric color range around zero
    bound = num.ceil(max(num.abs(data.min()), data.max()))

    if data_cpt is None:
        data_cpt = '/tmp/tempfile.cpt'
        gmt.makecpt(
            C='blue,white,red',
            Z=True,
            T='%g/%g' % (-bound, bound),
            out_filename=data_cpt, suppress_defaults=True)

    for idx, station in enumerate(stations):
        logger.debug('%s, %f' % (station.station, data[idx]))

    lons = [sta.lon for sta in stations]
    lats = [sta.lat for sta in stations]
    gmt.psxy(
        in_columns=(lons, lats, data.tolist()),
        C=data_cpt,
        *args)

    # colorbar position (and frame fill) depends on the map setup
    if dist > 30.:
        position, frame = 'x1.25c/0c+w5c/0.5c+jMC+h', False
    else:
        position, frame = 'x5.5c/4.1c+w5c/0.5c+jMC+h', '+gwhite'

    if not scale_label:
        logger.info('Not plotting scale as "scale_label" is None')
        return

    gmt.psscale(
        B='xa%s +l %s' % (num.floor(bound), scale_label),
        D=position,
        F=frame,
        C=data_cpt)
def draw_events(gmt, events, *args, **kwargs):
    """
    Plot event locations with ``gmt.psxy``.

    Parameters
    ----------
    gmt : :class:`pyrocko.plot.gmtpy.GMT`
    events : list
        of :class:`pyrocko.model.event.Event`
    *args, **kwargs
        forwarded to ``gmt.psxy`` (projection, region, symbol style ...)
    """
    lons, lats = [], []
    for event in events:
        lons.append(event.lon)
        lats.append(event.lat)

    gmt.psxy(in_columns=(lons, lats), *args, **kwargs)
def gmt_station_map_azimuthal(
        gmt, stations, event, data_cpt=None,
        data=None, max_distance=90, width=20, bin_width=15,
        fontsize=12, font='1', plot_names=True, scale_label='time-shifts [s]'):
    """
    Azimuth equidistant station map, if data given stations are colored
    accordingly

    Parameters
    ----------
    gmt : :class:`pyrocko.plot.gmtpy.GMT`
    stations : list
        of :class:`pyrocko.model.station.Station`
    event : :class:`pyrocko.model.event.Event`
    data_cpt : str
        path to gmt '*.cpt' file for coloring
    data : :class:`numpy.ndarray`
        1d vector length of stations to color stations
    max_distance : float
        maximum distance [deg] of event to map bound
    width : float
        plot width [cm]
    bin_width : float
        grid spacing [deg] for distance/ azimuth grid
    fontsize : int
        font-size in points for station labels
    font : str
        GMT font specification (number or name)
    plot_names : bool
        if True, label stations with "network.station"
    scale_label : str
        colorbar label used when ``data`` is given

    Raises
    ------
    :class:`pyrocko.plot.gmtpy.GmtPyError`
        if station names are to be plotted with a GMT version other than 5
    """
    max_distance = max_distance * 1.05  # add interval to have bound

    J_basemap = 'E0/-90/%s/%i' % (max_distance, width)
    J_location = 'E%s/%s/%s/%i' % (event.lon, event.lat, max_distance, width)
    R_location = '0/360/-90/0'

    gmt.psbasemap(
        R=R_location,
        J='S0/-90/90/%i' % width,
        B='xa%sf%s' % (bin_width * 2, bin_width))
    gmt.pscoast(
        R='g',
        J=J_location,
        D='c',
        G='darkgrey')

    # plotting equal distance circles
    bargs = ['-Bxg%f' % bin_width, '-Byg%f' % (2 * bin_width)]
    gmt.psbasemap(
        R='g', J=J_basemap, *bargs)

    if data is not None:
        draw_data_stations(
            gmt, stations, data, max_distance, data_cpt, scale_label, *(
                '-J%s' % J_location, '-R%s' % R_location, '-St14p'))
    else:
        st_lons = [station.lon for station in stations]
        st_lats = [station.lat for station in stations]
        gmt.psxy(
            R=R_location,
            J=J_location,
            in_columns=(st_lons, st_lats),
            G='red',
            S='t14p')

    # nothing to label without stations (avoids calling pstext without
    # any text rows)
    if plot_names and stations:
        if not gmt.is_gmt5():
            raise gmtpy.GmtPyError('Only GMT version 5.x supported!')

        farg = ['-F+f+j']
        alignment = 'TC'
        rows = [
            (st.lon, st.lat,
             '%i,%s,%s' % (fontsize, font, 'black'),
             alignment,
             '{}.{}'.format(st.network, st.station))
            for st in stations]

        gmt.pstext(
            in_rows=rows,
            R=R_location,
            J=J_location,
            N=True, *farg)

    draw_events(
        gmt, [event], *('-J%s' % J_location, '-R%s' % R_location),
        **dict(G='orange', S='a14p'))
def draw_station_map_gmt(problem, po):
    """
    Draws distance dependent station maps for teleseismic vs regional/local
    setups.

    Wavemaps with a maximum distance > 30 deg are drawn in an azimuthal
    equidistant projection, others with :class:`pyrocko.automap.Map`. If
    "time_shift" is in ``po.varnames`` the stations are colored by the
    time-shifts of the selected result point.

    Raises
    ------
    :class:`pyrocko.plot.gmtpy.GmtPyError`
        if GMT is not installed
    NotImplementedError
        for 'svg' output
    ValueError
        if variables other than "time_shift" are requested
    """
    if len(gmtpy.detect_gmt_installations()) < 1:
        raise gmtpy.GmtPyError(
            'GMT needs to be installed for station_map plot!')

    if po.outformat == 'svg':
        raise NotImplementedError('SVG format is not supported for this plot!')

    ts = 'time_shift'
    if ts in po.varnames:
        logger.info('Plotting time-shifts on station locations')
        stage = load_stage(
            problem, stage_number=po.load_stage, load='trace', chains=[-1])

        point = get_result_point(stage.mtrace, po.post_llk)
        value_string = '%i' % po.load_stage
    else:
        point = None
        value_string = '0'
        if len(po.varnames) > 0:
            raise ValueError(
                'Requested variables %s is not supported for plotting! '
                'Supported: %s' % (utility.list2string(po.varnames), ts))

    fontsize = 12
    font = '1'
    bin_width = 15  # major grid and tick increment in [deg]
    h = 15  # outsize in cm
    w = h - 5

    logger.info('Drawing Station Map ...')
    sc = problem.composites['seismic']
    event = problem.config.event

    gmtconfig = get_gmt_config(gmtpy, h=h, w=h)
    gmtconfig['MAP_LABEL_OFFSET'] = '4p'
    for wmap in sc.wavemaps:
        outpath = os.path.join(
            problem.outfolder, po.figure_dir, 'station_map_%s_%i_%s.%s' % (
                wmap.name, wmap.mapnumber, value_string, po.outformat))

        dist = max(wmap.config.distances)
        if not os.path.exists(outpath) or po.force:

            if point is not None:
                time_shifts = extract_time_shifts(point, wmap)
            else:
                time_shifts = None

            if dist > 30:
                logger.info(
                    'Using equidistant azimuthal projection for'
                    ' teleseismic setup of wavemap %s.' % wmap._mapid)

                gmt = gmtpy.GMT(config=gmtconfig)
                gmt_station_map_azimuthal(
                    gmt, wmap.stations, event,
                    data=time_shifts, max_distance=dist, width=w,
                    bin_width=bin_width,
                    fontsize=fontsize, font=font)

                gmt.save(outpath, resolution=po.dpi, size=w)

            else:
                logger.info(
                    'Using equidistant projection for regional setup '
                    'of wavemap %s.' % wmap._mapid)
                from pyrocko.automap import Map
                m = Map(
                    lat=event.lat,
                    lon=event.lon,
                    radius=dist * otd.d2m,
                    width=h,
                    height=h,
                    show_grid=True,
                    show_topo=True,
                    show_scale=True,
                    color_dry=(143, 188, 143),  # grey
                    illuminate=True,
                    illuminate_factor_ocean=0.15,
                    # illuminate_factor_land = 0.2,
                    show_rivers=True,
                    show_plates=False,
                    gmt_config=gmtconfig)

                # time_shifts is an array: test for None, not truthiness
                if time_shifts is not None:
                    sargs = m.jxyr + ['-St14p']
                    # data_cpt and scale_label are passed positionally,
                    # as the extra GMT arguments follow them
                    draw_data_stations(
                        m.gmt, wmap.stations, time_shifts, dist,
                        None, 'time shifts [s]', *sargs)

                    for st in wmap.stations:
                        text = '{}.{}'.format(st.network, st.station)
                        m.add_label(lat=st.lat, lon=st.lon, text=text)
                else:
                    m.add_stations(
                        wmap.stations, psxy_style=dict(S='t14p', G='red'))

                draw_events(
                    m.gmt, [event], *m.jxyr, **dict(G='yellow', S='a14p'))
                m.save(outpath, resolution=po.dpi, oversample=2., size=w)

            logger.info('saving figure to %s' % outpath)
        else:
            logger.info('Plot exists! Use --force to overwrite!')
def draw_3d_slip_distribution(problem, po):
    """
    Draw the 3d slip distribution of a finite fault result with GMT.

    The plotted quantity depends on ``po.varnames``:
    'coupling' / 'slip_deficit' use the coupling computed when an Euler pole
    correction is enabled, 'slip_variation' is the width of the 'uparr'
    bounds in the summary.txt file, otherwise the total slip of the given
    (or all) slip components is plotted.

    Raises
    ------
    NotImplementedError
        for 'svg' output
    ModeError
        if the problem is not in finite-fault mode
    """
    varname_choices = ['coupling', 'slip_deficit', 'slip_variation']

    if po.outformat == 'svg':
        raise NotImplementedError('SVG format is not supported for this plot!')

    mode = problem.config.problem_config.mode

    if mode != ffi_mode_str:
        raise ModeError(
            'Wrong optimization mode: %s! This plot '
            'variant is only valid for "%s" mode' % (mode, ffi_mode_str))

    if po.load_stage is None:
        po.load_stage = -1

    stage = load_stage(
        problem, stage_number=po.load_stage, load='trace', chains=[-1])

    if not po.reference:
        reference = problem.config.problem_config.get_test_point()
        res_point = get_result_point(stage.mtrace, po.post_llk)
        reference.update(res_point)
        llk_str = po.post_llk
        mtrace = stage.mtrace
    else:
        reference = po.reference
        llk_str = 'ref'
        mtrace = None

    datatype, cconf = list(problem.composites.items())[0]

    fault = cconf.load_fault_geometry()

    if po.plot_projection in ['local', 'latlon']:
        perspective = '135/30'
    else:
        perspective = po.plot_projection

    # first requested variable; None if no variables were requested
    first_varname = po.varnames[0] if len(po.varnames) > 0 else None

    # co-seismic slip unless an enabled Euler pole correction is found;
    # also covers setups without geodetic config or without Euler poles
    slip_units = 'm'

    gc = problem.config.geodetic_config
    if gc:
        for corr in gc.corrections_config.euler_poles:
            if corr.enabled:
                if first_varname in varname_choices:
                    from beat.ffi import backslip2coupling
                    logger.info('Plotting %s ...!', first_varname)
                    reference['coupling'] = backslip2coupling(
                        point=reference, fault=fault,
                        event=problem.config.event)

                    # TODO: cleanup if-orgy with slip units etc ...
                    if first_varname == 'coupling':
                        slip_units = '%'
                    else:
                        slip_units = 'm/yr'
                else:
                    logger.info(
                        'Found Euler pole correction assuming interseismic '
                        'slip-rates ...')
                    slip_units = 'm/yr'
            else:
                logger.info(
                    'Did not find Euler pole correction-assuming '
                    'co-seismic slip ...')
                slip_units = 'm'

    if first_varname == 'slip_variation':
        from pandas import read_csv
        from beat.backend import extract_bounds_from_summary
        summarydf = read_csv(
            os.path.join(problem.outfolder, 'summary.txt'), sep=r'\s+')
        bounds = extract_bounds_from_summary(
            summarydf, varname='uparr', shape=(fault.npatches,))
        reference['slip_variation'] = bounds[1] - bounds[0]
        slip_units = 'm'

    if len(po.varnames) == 0:
        varnames = None
    else:
        varnames = po.varnames

    if len(po.varnames) == 1:
        slip_label = po.varnames[0]
    else:
        slip_label = 'slip'

    if po.source_idxs is None:
        source_idxs = [0, fault.nsubfaults]
    else:
        source_idxs = po.source_idxs

    outpath = os.path.join(
        problem.outfolder,
        po.figure_dir,
        '3d_%s_distribution_%i_%s_%i.%s' % (
            slip_label, po.load_stage, llk_str,
            po.nensemble, po.outformat))

    if not os.path.exists(outpath) or po.force or po.outformat == 'display':
        logger.info('Drawing 3d slip-distribution plot ...')

        gmt = slip_distribution_3d_gmt(
            fault, reference, mtrace, perspective,
            slip_units, slip_label, varnames, source_idxs=source_idxs)

        logger.info('saving figure to %s' % outpath)
        gmt.save(outpath, resolution=300, size=10)
    else:
        logger.info('Plot exists! Use --force to overwrite!')
def slip_distribution_3d_gmt(
        fault, reference, mtrace=None, perspective='135/30', slip_units='m',
        slip_label='slip', varnames=None, gmt=None, bin_width=1,
        cptfilepath=None, transparency=0, source_idxs=None):
    """
    Draw the slip distribution of the fault patches in a 3d GMT view.

    Parameters
    ----------
    fault : :class:`beat.ffi.fault.FaultGeometry`
    reference : dict
        result point containing the plotted values
    mtrace : optional
        not used by this function
    perspective : str
        'azimuth/elevation' of the 3d view
    slip_units : str
        units shown in the colorbar label
    slip_label : str
        'coupling', 'slip_deficit', 'slip_variation' or any other label,
        which plots the total slip of ``varnames``
    varnames : list of str, optional
        slip components summed up for the total slip
    gmt : :class:`pyrocko.plot.gmtpy.GMT`, optional
        instance to draw on; a new one is created if None
    bin_width : float
        major grid and tick increment in [deg]
    cptfilepath : str, optional
        path of the color palette file to create
    transparency : int
        passed to the GMT ``-t`` option of the patches
    source_idxs : list of 2 int, optional
        [start, stop) range of subfault indexes to draw; all subfaults if
        None

    Returns
    -------
    :class:`pyrocko.plot.gmtpy.GMT`

    Raises
    ------
    :class:`pyrocko.plot.gmtpy.GmtPyError`
        if GMT is not installed
    """
    if len(gmtpy.detect_gmt_installations()) < 1:
        raise gmtpy.GmtPyError(
            'GMT needs to be installed for slip_distribution_3d plot!')

    if source_idxs is None:
        # same default as used by draw_3d_slip_distribution
        source_idxs = [0, fault.nsubfaults]

    p = 'z%s/0' % perspective

    if gmt is None:
        h = 15  # outsize in cm
        w = 22

        gmtconfig = get_gmt_config(gmtpy, h=h, w=w, fontsize=11)
        gmtconfig['MAP_FRAME_TYPE'] = 'plain'
        gmtconfig['MAP_SCALE_HEIGHT'] = '11p'
        gmt = gmtpy.GMT(config=gmtconfig)

    sf_lonlats = num.vstack(
        [sf.outline(cs='lonlat') for sf in fault.iter_subfaults(source_idxs)])

    sf_xyzs = num.vstack(
        [sf.outline(cs='xyz') for sf in fault.iter_subfaults(source_idxs)])
    _, _, max_depth = sf_xyzs.max(axis=0) / km

    lon_min, lat_min = sf_lonlats.min(axis=0)
    lon_max, lat_max = sf_lonlats.max(axis=0)

    # 10% margin around the fault outlines
    lon_tolerance = (lon_max - lon_min) * 0.1
    lat_tolerance = (lat_max - lat_min) * 0.1

    R = utility.list2string(
        [lon_min - lon_tolerance,
         lon_max + lon_tolerance,
         lat_min - lat_tolerance,
         lat_max + lat_tolerance,
         -max_depth, 0], '/')
    Jg = '-JM%fc' % 20
    Jz = '-JZ%gc' % 3
    J = [Jg, Jz]

    B = ['-Bxa%gg%g' % (bin_width, bin_width),
         '-Bya%gg%g' % (bin_width, bin_width),
         '-Bza10+Ldepth [km]', '-BWNesZ']
    args = J + B

    gmt.pscoast(
        R=R,
        D='a',
        G='gray90',
        S='lightcyan',
        p=p,
        *J)

    gmt.psbasemap(
        R=R,
        p=p,
        *args)

    if slip_label == 'coupling':
        reference_slips = reference['coupling'] * 100  # in percent
    elif slip_label == 'slip_deficit':
        reference_slips = reference['coupling'] * fault.get_total_slip(
            index=None, point=reference)
    elif slip_label == 'slip_variation':
        reference_slips = reference[slip_label]
    else:
        reference_slips = fault.get_total_slip(
            index=None, point=reference, components=varnames)

    autos = AutoScaler(snap='on', approx_ticks=3)

    cmin, cmax, cinc = autos.make_scale(
        (0, reference_slips.max()), override_mode='min-max')

    if cptfilepath is None:
        cptfilepath = '/tmp/tempfile.cpt'
        gmt.makecpt(
            C='hot',
            I='c',
            T='%f/%f' % (cmin, cmax),
            out_filename=cptfilepath, suppress_defaults=True)

    # private temporary file of this GMT instance instead of a shared
    # fixed path in /tmp
    tmp_patch_fname = gmt.tempfilename()
    for idx in range(*source_idxs):
        slips = fault.vector2subfault(index=idx, vector=reference_slips)
        for i, source in enumerate(fault.get_subfault_patches(idx)):
            lonlats = source.outline(cs='lonlat')
            xyzs = source.outline(cs='xyz') / km
            depths = xyzs[:, 2] * -1.  # make depths negative
            in_rows = num.hstack((lonlats, num.atleast_2d(depths).T))

            # polygon segment header carrying the patch value for coloring
            num.savetxt(
                tmp_patch_fname,
                in_rows,
                header='> -Z%f' % slips[i],
                comments='')

            gmt.psxyz(
                tmp_patch_fname,
                R=R,
                C=cptfilepath,
                L=True,
                t=transparency,
                W='0.1p',
                p=p,
                *J)

    # add a colorbar; its vertical position depends on the view azimuth
    azimuth, _ = perspective.split('/')
    if float(azimuth) < 180:
        ypos = 0
    else:
        ypos = 10

    D = 'x1.5c/%ic+w6c/0.5c+jMC+h' % ypos
    F = False

    gmt.psscale(
        B='xa%f +l %s [%s]' % (cinc, slip_label, slip_units),
        D=D,
        F=F,
        C=cptfilepath,
        finish=True)

    return gmt
def draw_lune_plot(problem, po):
    """
    Draw the posterior source-type parameters (v, w) on the lune with GMT.

    Parameters
    ----------
    problem : beat problem object
        single source problem
    po : :class:`PlotOptions`
        plot options; with ``po.nensemble > 1`` only an evenly spaced
        selection of the posterior samples is plotted

    Raises
    ------
    NotImplementedError
        for 'svg' output or problems with more than one source
    """
    if po.outformat == 'svg':
        raise NotImplementedError('SVG format is not supported for this plot!')

    if problem.config.problem_config.n_sources > 1:
        raise NotImplementedError(
            'Lune plot is not yet implemented for more than one source!')

    if po.load_stage is None:
        po.load_stage = -1

    stage = load_stage(
        problem, stage_number=po.load_stage, load='trace', chains=[-1])
    n_mts = len(stage.mtrace)

    result_ensemble = {}
    for varname in ('v', 'w'):
        try:
            samples = stage.mtrace.get_values(
                varname, combine=True, squeeze=True).ravel()
        except ValueError:
            # variable is not in the trace (fixed value): repeat its value
            fixed_value = problem.get_random_point()[varname]
            samples = num.full(n_mts, fixed_value, dtype=num.float64)
        result_ensemble[varname] = samples

    if po.reference:
        reference_v_tape = po.reference['v']
        reference_w_tape = po.reference['w']
        llk_str = 'ref'
    else:
        reference_v_tape = reference_w_tape = None
        llk_str = po.post_llk

    outpath = os.path.join(
        problem.outfolder,
        po.figure_dir,
        'lune_%i_%s_%i.%s' % (
            po.load_stage, llk_str, po.nensemble, po.outformat))

    v_tape = result_ensemble['v']
    w_tape = result_ensemble['w']
    if po.nensemble > 1:
        logger.info('Plotting selected ensemble as nensemble > 1 ...')
        selected = num.linspace(
            0, n_mts, po.nensemble, dtype='int', endpoint=False)
        v_tape = v_tape[selected]
        w_tape = w_tape[selected]
    else:
        logger.info('Plotting whole posterior ...')

    plot_exists = os.path.exists(outpath)
    if plot_exists and not po.force and po.outformat != 'display':
        logger.info('Plot exists! Use --force to overwrite!')
        return

    logger.info('Drawing Lune plot ...')
    gmt = lune_plot(
        v_tape=v_tape, w_tape=w_tape,
        reference_v_tape=reference_v_tape,
        reference_w_tape=reference_w_tape)

    logger.info('saving figure to %s' % outpath)
    gmt.save(outpath, resolution=300, size=10)
def lune_plot(
        v_tape=None, w_tape=None,
        reference_v_tape=None, reference_w_tape=None):
    """
    Plot a density estimate of source-type samples on the lune with GMT.

    The (v, w) samples are converted to lune longitude (gamma) and latitude
    (delta) in degrees, a spherical kde is evaluated on a regular grid and
    drawn as image. Lune arcs, special source-type points (ISO, CLVD, LVD,
    DC) and optionally a reference solution are drawn on top.

    Parameters
    ----------
    v_tape : array_like
        samples of the source-type parameter v
    w_tape : array_like
        samples of the source-type parameter w
    reference_v_tape : float, optional
        reference value of v; plotted only if not None
    reference_w_tape : float, optional
        reference value of w

    Returns
    -------
    :class:`pyrocko.plot.gmtpy.GMT`
        instance with the finished plot (not yet saved)

    Raises
    ------
    :class:`pyrocko.plot.gmtpy.GmtPyError`
        if GMT is not installed
    """
    from beat.sources import v_to_gamma, w_to_delta

    if len(gmtpy.detect_gmt_installations()) < 1:
        raise gmtpy.GmtPyError(
            'GMT needs to be installed for lune_plot!')

    fontsize = 14
    font = '1'

    def draw_lune_arcs(gmt, R, J):
        # fixed end points [deg] of the arcs drawn across the lune

        lons = [30., -30., 30., -30.]
        lats = [54.7356, 35.2644, -35.2644, -54.7356]

        gmt.psxy(
            in_columns=(lons, lats), N=True, W='1p,black', R=R, J=J)

    def draw_lune_points(gmt, R, J, labels=True):
        # lune coordinates [deg] of the special source types and their
        # annotations / text alignments

        lons = [0., -30., -30., -30., 0., 30., 30., 30., 0.]
        lats = [-90., -54.7356, 0., 35.2644, 90., 54.7356, 0., -35.2644, 0.]
        annotations = [
            '-ISO', '', '+CLVD', '+LVD', '+ISO', '', '-CLVD', '-LVD', 'DC']
        alignments = ['TC', 'TC', 'RM', 'RM', 'BC', 'BC', 'LM', 'LM', 'TC']

        gmt.psxy(in_columns=(lons, lats), N=True, S='p6p', W='1p,0', R=R, J=J)

        rows = []
        if labels:
            farg = ['-F+f+j']
            for lon, lat, text, align in zip(
                    lons, lats, annotations, alignments):

                rows.append((
                    lon, lat,
                    '%i,%s,%s' % (fontsize, font, 'black'),
                    align, text))

            gmt.pstext(
                in_rows=rows,
                N=True, R=R, J=J, D='j5p', *farg)

    def draw_lune_kde(
            gmt, v_tape, w_tape, grid_size=(200, 200), R=None, J=None):

        def check_fixed(a, varname):
            # Adds random jitter IN PLACE if the spread of the samples is
            # too small for the kde. NOTE(review): this only affects the
            # caller's data if ``a`` is an ndarray (in-place ``+=``).
            if a.std() < 0.1:
                logger.info(
                    'Spread of variable "%s" is %f, which is below necessary'
                    ' width to estimate a spherical kde, adding some jitter to'
                    ' make kde estimate possible' % (varname, a.std()))
                a += num.random.normal(loc=0., scale=0.05, size=a.size)

        gamma = num.rad2deg(v_to_gamma(v_tape))  # lune longitude [deg]
        delta = num.rad2deg(w_to_delta(w_tape))  # lune latitude [deg]

        check_fixed(gamma, varname='v')
        check_fixed(delta, varname='w')

        # regular evaluation grid covering the whole lune
        lats_vec, lats_inc = num.linspace(
            -90., 90., grid_size[0], retstep=True)
        lons_vec, lons_inc = num.linspace(
            -30., 30., grid_size[1], retstep=True)
        lons, lats = num.meshgrid(lons_vec, lats_vec)

        kde_vals, _, _ = spherical_kde_op(
            lats0=delta, lons0=gamma,
            lons=lons, lats=lats, grid_size=grid_size)
        # color range always includes zero
        Tmin = num.min([0., kde_vals.min()])
        Tmax = num.max([0., kde_vals.max()])

        cptfilepath = '/tmp/tempfile.cpt'
        gmt.makecpt(
            C='white,yellow,orange,red,magenta,violet',
            Z=True, D=True,
            T='%f/%f' % (Tmin, Tmax),
            out_filename=cptfilepath, suppress_defaults=True)

        # grid the kde values and draw them as image
        grdfile = gmt.tempfilename()
        gmt.xyz2grd(
            G=grdfile, R=R, I='%f/%f' % (lons_inc, lats_inc),
            in_columns=(lons.ravel(), lats.ravel(), kde_vals.ravel()),  # noqa
            out_discard=True)

        gmt.grdimage(grdfile, R=R, J=J, C=cptfilepath)

        # gmt.pscontour(
        #    in_columns=(lons.ravel(), lats.ravel(),  kde_vals.ravel()),
        #    R=R, J=J, I=True, N=True, A=True, C=cptfilepath)
        # -Ctmp_$out.cpt -I -N -A- -O -K >> $ps

    def draw_reference_lune(gmt, R, J, reference_v_tape, reference_w_tape):
        # single marker at the reference source type

        gamma = num.rad2deg(
            v_to_gamma(reference_v_tape))  # lune longitude [deg]
        delta = num.rad2deg(
            w_to_delta(reference_w_tape))  # lune latitude [deg]

        gmt.psxy(
            in_rows=[(float(gamma), float(delta))],
            N=True, G='blue', W='1p,black', S='p3p', R=R, J=J)

    # paper size [cm]
    h = 20.
    w = h / 1.9

    gmtconfig = get_gmt_config(gmtpy, h=h, w=w)
    bin_width = 15  # tick increment

    # Hammer projection of the lune region (lon -30..30, lat -90..90)
    J = 'H0/%f' % (w - 5.)
    R = '-30/30/-90/90'
    B = 'f%ig%i/f%ig%i' % (bin_width, bin_width, bin_width, bin_width)
    # range_arg="-T${zmin}/${zmax}/${dz}"

    gmt = gmtpy.GMT(config=gmtconfig)

    draw_lune_kde(
        gmt, v_tape=v_tape, w_tape=w_tape, grid_size=(701, 301), R=R, J=J)
    gmt.psbasemap(R=R, J=J, B=B)
    draw_lune_arcs(gmt, R=R, J=J)
    draw_lune_points(gmt, R=R, J=J)

    if reference_v_tape is not None:
        draw_reference_lune(
            gmt, R=R, J=J,
            reference_v_tape=reference_v_tape,
            reference_w_tape=reference_w_tape)

    return gmt
def draw_station_map_cartopy(problem, po):
    """
    Draw one station map per seismic wavemap using cartopy.

    If the largest configured station distance of a wavemap exceeds 30
    (presumably degrees -- TODO confirm against the wavemap config), a
    global orthographic map centred on the event is drawn. Otherwise a
    regional PlateCarree map extending that distance (rounded up) around
    the event is drawn. Stations are plotted as red triangles with their
    names, the event as a star.

    If cartopy is not importable, an error is logged and the function
    returns without plotting.

    Parameters
    ----------
    problem : object
        needs ``config.problem_config.datatypes``, ``config.event``,
        ``composites['seismic']`` (with ``wavemaps``) and ``outfolder``
    po : object
        plot options, presumably :class:`PlotOptions`; uses ``figure_dir``,
        ``outformat``, ``force`` and ``dpi``

    Raises
    ------
    TypeError
        if the problem datatypes do not include "seismic"
    """
    import matplotlib.ticker as mticker

    logger.info('Drawing Station Map ...')
    try:
        import cartopy as ctp
    except ImportError:
        logger.error(
            'Cartopy is not installed. '
            'For a station map cartopy needs to be installed!')
        return

    def draw_gridlines(ax, grid_proj):
        # Only the rotated-frame parallels at 30/60/90 are drawn (meridians
        # are switched off). With the pole at the event these are presumably
        # epicentral-distance rings -- NOTE(review): confirm.
        gl = ax.gridlines(crs=grid_proj, color='black', linewidth=0.5)
        gl.n_steps = 300
        gl.xlines = False
        gl.ylocator = mticker.FixedLocator([30, 60, 90])

    fontsize = 12

    if 'seismic' not in problem.config.problem_config.datatypes:
        raise TypeError(
            'Station map is available only for seismic stations!'
            ' However, the datatypes do not include "seismic" data')

    event = problem.config.event
    sc = problem.composites['seismic']

    mpl_init(fontsize=fontsize)
    stations_proj = ctp.crs.PlateCarree()

    for wmap in sc.wavemaps:
        outpath = os.path.join(
            problem.outfolder, po.figure_dir, 'station_map_%s_%i.%s' % (
                wmap.name, wmap.mapnumber, po.outformat))

        if os.path.exists(outpath) and not po.force:
            logger.info('Plot exists! Use --force to overwrite!')
            continue

        # use the same distance for the global/regional decision and for
        # the regional extent
        max_distance = max(wmap.config.distances)

        if max_distance > 30:
            map_proj = ctp.crs.Orthographic(
                central_longitude=event.lon, central_latitude=event.lat)
            extent = None
        else:
            max_dist = math.ceil(max_distance)
            map_proj = ctp.crs.PlateCarree()
            extent = [
                event.lon - max_dist, event.lon + max_dist,
                event.lat - max_dist, event.lat + max_dist]

        fig, ax = plt.subplots(
            nrows=1, ncols=1, figsize=mpl_papersize('a6', 'landscape'),
            subplot_kw={'projection': map_proj})

        if extent:
            # regional map; mpl_margins is called for its effect on the
            # figure layout, its return value is not needed
            mpl_margins(
                fig, left=2, bottom=2, top=2, right=2, units=fontsize)

            import cartopy.feature as cfeature
            from cartopy.mpl.gridliner import \
                LONGITUDE_FORMATTER, LATITUDE_FORMATTER

            ax.set_extent(extent, crs=map_proj)
            ax.add_feature(cfeature.NaturalEarthFeature(
                category='physical', name='land',
                scale='50m', **cfeature.LAND.kwargs))
            ax.add_feature(cfeature.NaturalEarthFeature(
                category='physical', name='ocean',
                scale='50m', **cfeature.OCEAN.kwargs))

            gl = ax.gridlines(
                color='black', linewidth=0.5, draw_labels=True)
            gl.ylocator = tick.MaxNLocator(nbins=5)
            gl.xlocator = tick.MaxNLocator(nbins=5)
            # cartopy >= 0.18 renamed xlabels_top/ylabels_right
            if hasattr(gl, 'top_labels'):
                gl.top_labels = False
                gl.right_labels = False
            else:
                gl.xlabels_top = False
                gl.ylabels_right = False
            gl.xformatter = LONGITUDE_FORMATTER
            gl.yformatter = LATITUDE_FORMATTER
        else:
            # global teleseismic map
            mpl_margins(
                fig, left=1, bottom=1, top=1, right=1, units=fontsize)
            grid_proj = ctp.crs.RotatedPole(
                pole_longitude=event.lon, pole_latitude=event.lat)
            ax.coastlines(linewidth=0.2)
            draw_gridlines(ax, grid_proj)
            ax.stock_img()

        for station in wmap.stations:
            ax.plot(
                station.lon, station.lat, 'r^', transform=stations_proj,
                markeredgecolor='black', markeredgewidth=0.3)
            ax.text(
                station.lon, station.lat, station.station, fontsize=10,
                transform=stations_proj,
                horizontalalignment='center', verticalalignment='top')

        ax.plot(
            event.lon, event.lat, '*', transform=stations_proj,
            markeredgecolor='black', markeredgewidth=0.3, markersize=12,
            markerfacecolor=scolor('butter1'))

        if po.outformat == 'display':
            plt.show()
        else:
            logger.info('saving figure to %s' % outpath)
            fig.savefig(outpath, format=po.outformat, dpi=po.dpi)

        # release the figure; otherwise one figure per wavemap stays alive
        plt.close(fig)
# Map each plot name to the function that draws it. Insertion order is the
# order returned by ``available_plots()`` when called without a mode.
# Note: 'station_map' is bound to the GMT implementation, not the cartopy one.
plots_catalog = dict(
    correlation_hist=draw_correlation_hist,
    stage_posteriors=draw_posteriors,
    waveform_fits=draw_seismic_fits,
    scene_fits=draw_scene_fits,
    gnss_fits=draw_gnss_fits,
    velocity_models=draw_earthmodels,
    slip_distribution=draw_slip_dist,
    slip_distribution_3d=draw_3d_slip_distribution,
    hudson=draw_hudson,
    lune=draw_lune_plot,
    fuzzy_beachball=draw_fuzzy_beachball,
    fuzzy_mt_decomp=draw_fuzzy_mt_decomposition,
    moment_rate=draw_moment_rate,
    station_map=draw_station_map_gmt,
)
# plots available for every problem mode
common_plots = [
    'stage_posteriors']

# plots that need seismic data
seismic_plots = [
    'station_map',
    'waveform_fits',
    'fuzzy_mt_decomp',
    'hudson',
    'lune',
    'fuzzy_beachball']

# plots that need geodetic data
geodetic_plots = [
    'scene_fits',
    'gnss_fits']

# plots specific to the 'geometry' mode
geometry_plots = [
    'correlation_hist',
    'velocity_models']

# plots specific to the 'ffi' mode
ffi_plots = [
    'moment_rate',
    'slip_distribution']

plots_mode_catalog = {
    'geometry': common_plots + geometry_plots,
    'ffi': common_plots + ffi_plots,
}

plots_datatype_catalog = {
    'seismic': seismic_plots,
    'geodetic': geodetic_plots,
}


def available_plots(mode=None, datatypes=('geodetic', 'seismic')):
    """
    Return the names of the plots available for a mode and set of datatypes.

    Parameters
    ----------
    mode : str, optional
        key of ``plots_mode_catalog`` ('geometry' or 'ffi'). If None, all
        plot names of ``plots_catalog`` are returned and ``datatypes`` is
        ignored.
    datatypes : iterable of str
        keys of ``plots_datatype_catalog`` ('geodetic', 'seismic'); their
        plots are appended after the mode-specific plots

    Returns
    -------
    list of str
        a new list; the module-level catalogs are never modified

    Raises
    ------
    KeyError
        for an unknown ``mode`` or datatype
    """
    if mode is None:
        return list(plots_catalog.keys())

    # copy: extending the catalog list in place would make every further
    # call return accumulated duplicates
    plots = list(plots_mode_catalog[mode])
    for datatype in datatypes:
        plots.extend(plots_datatype_catalog[datatype])
    return plots