# Source code for backend

"""
File trace backends modified from pymc to work efficiently with
SMC

Store sampling values as CSV or binary files.

File format
-----------

Sampling values for each chain are saved in a separate file (under a
directory specified by the `dir_path` argument).  The rows correspond to
sampling iterations.  The column names consist of variable names and
index labels.  For example, the heading

  x,y__0_0,y__0_1,y__1_0,y__1_1,y__2_0,y__2_1

represents two variables, x and y, where x is a scalar and y has a
shape of (3, 2).
"""
import copy
import itertools
import json
import logging
import os
import shutil
from collections import OrderedDict
from glob import glob
from time import time

import numpy as num
import pandas as pd
from pandas.errors import EmptyDataError

# pandas version control
try:
    from pandas.io.common import CParserError
except ImportError:
    from pandas.errors import ParserError as CParserError

from typing import (
    Set,
)

from arviz import convert_to_inference_data

# from arviz.data.base import dict_to_dataset
from pymc.backends import base, ndarray
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.model import modelcontext
from pymc.step_methods.arraystep import BlockedStep
from pyrocko import util

from beat.config import sample_p_outname, transd_vars_dist
from beat.covariance import calc_sample_covariance
from beat.utility import (
    ListArrayOrdering,
    ListToArrayBijection,
    dump_objects,
    load_objects,
)

# Module-level logger used by all trace backends and helpers in this file.
logger = logging.getLogger("backend")


def _create_flat_names(varname, shape):
    """Return flat variable names for `varname` of `shape`.

    Examples
    --------
    >>> _create_flat_names('x', (5,))
    ['x__0', 'x__1', 'x__2', 'x__3', 'x__4']

    >>> _create_flat_names('x', (2, 2))
    ['x__0_0', 'x__0_1', 'x__1_0', 'x__1_1']
    """
    if not shape:
        return [varname]
    labels = (num.ravel(xs).tolist() for xs in num.indices(shape))
    labels = (map(str, xs) for xs in labels)
    return ["{}__{}".format(varname, "_".join(idxs)) for idxs in zip(*labels)]


def _create_flat_names_summary(varname, shape):
    if not shape or sum(shape) == 1:
        return [varname]

    labels = (num.ravel(xs).tolist() for xs in num.indices(shape))
    labels = (map(str, [xs]) for xs in labels)
    return ["{}{}".format(varname, "".join(idxs)) for idxs in zip(*labels)]


def _create_shape(flat_names):
    """Determine shape from `_create_flat_names` output."""
    try:
        _, shape_str = flat_names[-1].rsplit("__", 1)
    except ValueError:
        return ()
    return tuple(int(i) + 1 for i in shape_str.split("_"))


def thin_buffer(buffer, buffer_thinning, ensure_last=True):
    """
    Reduce a list of objects by keeping only every nth entry.

    Parameters
    ----------
    buffer : list
        of objects to be thinned
    buffer_thinning : int
        every nth object in list is returned
    ensure_last : bool
        if True, thinning is anchored at the end of the list so the last
        object is always kept; otherwise it is anchored at the first object

    Returns
    -------
    list of the retained objects, in their original order
    """
    if not ensure_last:
        return buffer[::buffer_thinning]
    # walk backwards from the last element, then restore forward order
    return buffer[::-buffer_thinning][::-1]
class ArrayStepSharedLLK(BlockedStep):
    """
    Modified ArrayStepShared that handles a returned, larger point which
    additionally carries the likelihood values.

    Besides the sampled variables, it takes a list of output variables
    (including the likelihoods) that are stored in the traces.

    Parameters
    ----------
    value_vars : list
        variables to be sampled
    out_vars : list
        variables to be stored in the traces
    shared : dict
        pytensor variable -> shared variables
    """

    def __init__(self, value_vars, out_vars, shared):
        self.value_vars = value_vars
        self.lordering = ListArrayOrdering(out_vars, intype="tensor")
        test_point = [out_var.tag.test_value for out_var in out_vars]

        # shared variables are looked up by the name of their pytensor variable
        self.shared = {var.name: shared_var for var, shared_var in shared.items()}
        self.blocked = True

        sampled_names = {value_var.name for value_var in value_vars}
        # output entries that are not sampled variables
        # (assumes lordering.variables holds variable names -- TODO confirm)
        blacklist = list(set(self.lordering.variables) - sampled_names)

        self.bij = DictToArrayBijection()
        self.lij = ListToArrayBijection(self.lordering, test_point, blacklist=blacklist)

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)

    def step(self, point):
        """
        Perform one step.

        Updates the shared variables from `point`, then calls `astep` on the
        raveled values of `value_vars`.

        Returns
        -------
        tuple of (new point dict, list returned by `astep`)
        """
        for var_name, shared_var in self.shared.items():
            shared_var.set_value(point[var_name])

        # restrict the point to value_vars and keep their order consistent
        point = {value_var.name: point[value_var.name] for value_var in self.value_vars}
        raveled = self.bij.map(point)

        new_apoint, step_output = self.astep(raveled.data)
        if not isinstance(new_apoint, RaveledVars):
            # We assume that the mapping has stayed the same
            new_apoint = RaveledVars(new_apoint, raveled.point_map_info)

        return self.bij.rmap(new_apoint, start_point=point), step_output
[docs] class BaseChain(object): """ Base chain object, independent of file or memory output. Parameters ---------- model : Model If None, the model is taken from the `with` context. value_vars : list of variables Sampling values will be stored for these variables. If None, `model.unobserved_RVs` is used. """ def __init__( self, model=None, value_vars=None, buffer_size=5000, buffer_thinning=1 ): self.var_shapes = None self.chain = None self.buffer_size = buffer_size self.buffer_thinning = buffer_thinning self.buffer = [] self.count = 0 self.cov_counter = 0 if model is not None: model = modelcontext(model) if value_vars is None and model is not None: value_vars = model.unobserved_RVs if value_vars is not None: # Get variable shapes. Most backends will need this # information. self.var_shapes = OrderedDict() self.var_dtypes = OrderedDict() self.varnames = [] for var in value_vars: self.var_shapes[var.name] = var.tag.test_value.shape self.var_dtypes[var.name] = var.tag.test_value.dtype self.varnames.append(var.name) else: logger.debug("No model or variables given!") def __getitem__(self, idx): if isinstance(idx, slice): return self._slice(idx) try: return self.point(int(idx)) except (ValueError, TypeError): # Passed variable or variable name. raise ValueError("Can only index with slice or integer") def __getstate__(self): return self.__dict__ def __setstate__(self, state): self.__dict__.update(state)
[docs] def buffer_write(self, lpoint, draw): """ Write sampling results into buffer. If buffer is full trow an error. """ self.count += 1 self.buffer.append((lpoint, draw)) if self.count == self.buffer_size: raise BufferError("Buffer is full! Needs recording!!")
def empty_buffer(self): self.buffer = [] self.count = 0
[docs] def get_sample_covariance(self, step): """ Return sample Covariance matrix from buffer. """ sample_difference = self.count - self.buffer_size if sample_difference < 0: raise ValueError("Covariance has been updated already!") elif sample_difference > 0: raise BufferError("Buffer is not full and sample covariance may be biased") else: logger.info( "Evaluating sampled trace covariance of worker %i at " "sample %i" % (self.chain, step.cumulative_samples) ) cov = calc_sample_covariance( self.buffer, lij=step.lij, bij=step.bij, beta=step.beta ) self.cov_counter += 1 return cov
class FileChain(BaseChain):
    """
    Base class for a trace written to a file with buffer functionality and
    progressbar. Buffer is a list of tuples of lpoints and a draw index.
    Inheriting classes must define the methods: '_write_data_to_file'
    and '_load_df'

    Parameters
    ----------
    dir_path : str
        directory for the chain files; created with `os.mkdir` (single
        level) if it does not exist
    model, value_vars, buffer_size, buffer_thinning
        see :class:`BaseChain`
    progressbar : bool
        if False, a log message is emitted when large buffers are written
    k : int, optional
        if given, variables listed in `transd_vars_dist` get flat names for
        shape ``(k,)`` instead of their test-value shape
    """

    def __init__(
        self,
        dir_path="",
        model=None,
        value_vars=None,
        buffer_size=5000,
        buffer_thinning=1,
        progressbar=False,
        k=None,
    ):
        super(FileChain, self).__init__(
            model=model,
            value_vars=value_vars,
            buffer_size=buffer_size,
            buffer_thinning=buffer_thinning,
        )

        if not os.path.exists(dir_path):
            os.mkdir(dir_path)

        self.dir_path = dir_path

        self.flat_names = OrderedDict()
        if self.var_shapes is not None:
            for var, shape in self.var_shapes.items():
                # trans-d variables are sized by the number of dimensions k
                if k is not None and var in transd_vars_dist:
                    shape = (k,)
                self.flat_names[var] = _create_flat_names(var, shape)

        self.k = k

        self.corrupted_flag = False
        self.progressbar = progressbar

        self.stored_samples = 0
        self.draws = 0
        self._df = None
        self.filename = None
        self.derived_mapping = None

    def __len__(self):
        """Number of samples stored on disk plus those still buffered."""
        if self.filename is None:
            return 0

        self._load_df()

        if self._df is None:
            return 0
        else:
            return self._df.shape[0] + len(self.buffer)

    def add_derived_variables(self, varnames, shapes):
        """
        Register additional (derived) variables to be stored in the trace.

        A derived variable that already exists is moved to the end of
        `varnames` and its shape is the element-wise sum of the old and the
        new shape; `derived_mapping` records old -> new position so `write`
        can merge the values.

        Raises
        ------
        ValueError
            if `varnames` and `shapes` differ in length
        """
        nshapes = len(shapes)
        nvars = len(varnames)
        if nvars != nshapes:
            raise ValueError(
                "Inconsistent number of variables %i and shapes %i!"
                % (nvars, nshapes)
            )

        self.derived_mapping = {}
        for varname, shape in zip(varnames, shapes):
            if varname in self.varnames:
                exist_idx = self.varnames.index(varname)
                self.varnames.pop(exist_idx)
                exist_shape = self.var_shapes[varname]
                shape = tuple(map(sum, zip(exist_shape, shape)))
                concat_idx = len(self.varnames)
                self.derived_mapping[exist_idx] = concat_idx

            self.flat_names[varname] = _create_flat_names(varname, shape)
            self.var_shapes[varname] = shape
            self.var_dtypes[varname] = "float64"
            self.varnames.append(varname)

    def _load_df(self):
        raise ValueError("This method must be defined in inheriting classes!")

    def _write_data_to_file(self):
        raise ValueError("This method must be defined in inheriting classes!")

    def data_file(self):
        return self._df

    def record_buffer(self):
        """
        Write the (thinned) buffer to file and empty it.

        Raises
        ------
        ValueError
            if `setup` has not been called (no chain number assigned)
        """
        if self.chain is None:
            raise ValueError("Chain has not been setup. Saving samples not possible!")

        n_samples = len(self.buffer)
        self.stored_samples += n_samples

        if not self.progressbar:
            if n_samples > self.buffer_size // 2:
                logger.info(
                    "Writing %i / %i samples of chain %i to disk..."
                    % (self.stored_samples, self.draws, self.chain)
                )

        t0 = time()
        logger.debug("Start Record: Chain_%i" % self.chain)
        self._write_data_to_file()
        t1 = time()
        logger.debug("End Record: Chain_%i" % self.chain)
        logger.debug("Writing to file took %f" % (t1 - t0))
        self.empty_buffer()

    def write(self, lpoint, draw):
        """
        Write sampling results into buffer.
        If buffer is full write samples to file.

        Note: when derived variables are registered, `lpoint` is modified
        in place (the existing entry is popped and stacked onto the derived
        entry).
        """
        self.count += 1
        if self.derived_mapping:
            for exist_idx, concat_idx in self.derived_mapping.items():
                value = lpoint.pop(exist_idx)
                lpoint[concat_idx] = num.hstack((value, lpoint[concat_idx]))

        self.buffer.append((lpoint, draw))
        if self.count == self.buffer_size:
            self.record_buffer()

    def clear_data(self):
        """
        Clear the data loaded from file.
        """
        self._df = None

    def _get_sampler_stats(
        self, stat_name: str, sampler_idx: int, burn: int, thin: int
    ) -> num.ndarray:
        """Get sampler statistics."""
        raise NotImplementedError()

    @property
    def stat_names(self) -> Set[str]:
        """Names of sampler statistics; empty as no stats are recorded."""
        names: Set[str] = set()
        # `sampler_vars` is never assigned on file chains; previously this
        # raised AttributeError instead of reporting "no stats".
        for sampler_stats in getattr(self, "sampler_vars", None) or []:
            names.update(sampler_stats.keys())
        return names
class MemoryChain(BaseChain):
    """
    Slim memory trace object. Keeps points in a list in memory.

    Parameters
    ----------
    buffer_size : int
        number of samples after which the buffer is considered full

    `setup` additionally takes `draws` (number of samples) and `chain`
    (chain number).
    """

    def __init__(self, buffer_size=5000):
        super(MemoryChain, self).__init__(buffer_size=buffer_size)

    def setup(self, draws, chain, overwrite=False):
        self.draws = draws
        self.chain = chain
        # start from an empty buffer on request or if none exists yet
        if overwrite or self.buffer is None:
            self.buffer = []

    def record_buffer(self):
        logger.debug("Emptying buffer of trace %i" % self.chain)
        self.empty_buffer()
class TextChain(FileChain):
    """
    Text trace object based on '.csv' files. Slow in reading and writing.
    Good for debugging.

    Parameters
    ----------
    dir_path : str
        Name of directory to store text files
    model : Model
        If given, its `unobserved_RVs` are used when `value_vars` is None.
    value_vars : list of variables
        Sampling values will be stored for these variables.
    buffer_size : int
        this is the number of samples after which the buffer is written to disk
        or if the chain end is reached
    buffer_thinning : int
        every nth sample of the buffer is written to disk
    progressbar : boolean
        flag if a progressbar is active, if not a logmessage is printed
        every time the buffer is written to disk
    k : int, optional
        if given dont use shape from testpoint as size of transd variables
    """

    def __init__(
        self,
        dir_path,
        model=None,
        value_vars=None,
        buffer_size=5000,
        buffer_thinning=1,
        progressbar=False,
        k=None,
    ):
        super(TextChain, self).__init__(
            dir_path,
            model,
            value_vars,
            buffer_size=buffer_size,
            progressbar=progressbar,
            k=k,
            buffer_thinning=buffer_thinning,
        )

    def setup(self, draws, chain, overwrite=False):
        """
        Perform chain-specific setup.

        Creates ``chain-<chain>.csv`` with a header of flat variable names,
        unless the file exists and `overwrite` is False (then it is appended).

        Parameters
        ----------
        draws : int
            Expected number of draws
        chain : int
            Chain number
        overwrite : bool
            rewrite the file (header only) even if it exists
        """
        logger.debug("SetupTrace: Chain_%i step_%i" % (chain, draws))
        self.chain = chain
        self.draws = draws
        self.filename = os.path.join(self.dir_path, "chain-{}.csv".format(chain))

        cnames = [fv for v in self.varnames for fv in self.flat_names[v]]

        if os.path.exists(self.filename) and not overwrite:
            logger.debug("Found existing trace, appending!")
        else:
            self.count = 0
            # writing header
            with open(self.filename, "w") as fh:
                fh.write(",".join(cnames) + "\n")

    def _write_data_to_file(self, lpoint=None):
        """
        Write the lpoint to file.
        If lpoint is None it will try to write from buffer.

        Write errors are logged, not raised (best effort).

        Parameters
        ----------
        lpoint: list of numpy arrays
        """

        def lpoint2file(filehandle, values):
            # one CSV row: all arrays flattened and concatenated
            columns = itertools.chain.from_iterable(
                map(str, value.ravel()) for value in values
            )
            filehandle.write(",".join(columns) + "\n")

        if lpoint is None and len(self.buffer) == 0:
            logger.debug("There is no data to write into file.")
            return

        # Write text rows
        try:
            with open(self.filename, mode="a+") as fh:
                if lpoint is None:
                    # write out thinned buffer starting with last sample
                    write_buffer = thin_buffer(
                        self.buffer, self.buffer_thinning, ensure_last=True
                    )
                    for buffered_lpoint, _ in write_buffer:
                        lpoint2file(fh, buffered_lpoint)
                else:
                    lpoint2file(fh, lpoint)

        except EnvironmentError as e:
            # logged rather than printed so it is not lost in worker processes
            logger.error("Error on write file: %s", e)

    def _load_df(self):
        """
        Read the CSV file into `self._df` once.

        Empty or malformed files are deleted and flagged via
        `corrupted_flag`. If no flat names are known yet, variables and
        shapes are reconstructed from the column headers.
        """
        if self._df is None:
            try:
                self._df = pd.read_csv(self.filename)
            except EmptyDataError:
                logger.warning(
                    "Trace %s is empty and needs to be resampled!" % self.filename
                )
                os.remove(self.filename)
                self.corrupted_flag = True
            except CParserError:
                logger.warning("Trace %s has wrong size!" % self.filename)
                self.corrupted_flag = True
                os.remove(self.filename)

            if len(self.flat_names) == 0 and not self.corrupted_flag:
                self.flat_names, self.var_shapes = extract_variables_from_df(self._df)
                self.varnames = list(self.var_shapes.keys())

    def get_values(self, varname, burn=0, thin=1):
        """
        Get values from trace.

        Parameters
        ----------
        varname : str
            Variable name for which values are to be retrieved.
        burn : int
            Burn-in samples from trace. This is the number of samples to be
            thrown out from the start of the trace
        thin : int
            Number of thinning samples. Throw out every 'thin' sample of the
            trace.

        Returns
        -------
        :class:`numpy.array` of shape ``(n_samples,) + var_shape`` before
        burn/thin slicing

        Raises
        ------
        ValueError
            if `varname` is not part of the trace
        """
        self._load_df()
        try:
            var_df = self._df[self.flat_names[varname]]
            shape = (self._df.shape[0],) + self.var_shapes[varname]
            vals = var_df.values.ravel().reshape(shape)
            return vals[burn::thin]
        except KeyError:
            raise ValueError(
                'Did not find varname "%s" in sampling ' "results! Fixed?" % varname
            )

    def _slice(self, idx):
        if idx.stop is not None:
            raise ValueError("Stop value in slice not supported.")
        return ndarray._slice_as_ndarray(self, idx)

    def point(self, idx):
        """
        Get point of current chain with variables names as keys.

        Parameters
        ----------
        idx : int
            Index of the nth step of the chain

        Returns
        -------
        dictionary of point values
        """
        idx = int(idx)
        self._load_df()
        pt = {}
        for varname in self.varnames:
            # needs deepcopy otherwise reference to df is kept repetead calls
            # lead to memory leak
            vals = self._df[self.flat_names[varname]].iloc[idx]
            pt[varname] = copy.deepcopy(vals.values.reshape(self.var_shapes[varname]))
            del vals

        return pt
class NumpyChain(FileChain):
    """
    Numpy binary trace object based on '.bin' files. Fast in reading and
    writing. Bad for debugging.

    The file starts with one JSON header line (flat names, shapes, dtypes)
    followed by raw records of a structured numpy dtype.

    Parameters
    ----------
    dir_path : str
        Name of directory to store binary files
    model : Model
        If given, its `unobserved_RVs` are used when `value_vars` is None.
    value_vars : list of variables
        Sampling values will be stored for these variables.
    buffer_size : int
        this is the number of samples after which the buffer is written to disk
        or if the chain end is reached
    buffer_thinning : int
        every nth sample of the buffer is written to disk
    progressbar : boolean
        flag if a progressbar is active, if not a logmessage is printed
        every time the buffer is written to disk
    k : int, optional
        if given dont use shape from testpoint as size of transd variables
    """

    flat_names_tag = "flat_names"
    var_shape_tag = "var_shapes"
    var_dtypes_tag = "var_dtypes"

    __data_structure = None

    def __init__(
        self,
        dir_path,
        model=None,
        value_vars=None,
        buffer_size=5000,
        progressbar=False,
        k=None,
        buffer_thinning=1,
    ):
        super(NumpyChain, self).__init__(
            dir_path,
            model,
            value_vars,
            progressbar=progressbar,
            buffer_size=buffer_size,
            buffer_thinning=buffer_thinning,
            k=k,
        )

        self.k = k

    def __repr__(self):
        # model / value_vars are not stored on the chain, so reading them
        # directly raised AttributeError; fall back to None instead.
        return "NumpyChain({},{},{},{},{},{})".format(
            self.dir_path,
            getattr(self, "model", None),
            getattr(self, "value_vars", None),
            self.buffer_size,
            self.progressbar,
            self.k,
        )

    @property
    def data_structure(self):
        return self.__data_structure

    @property
    def file_header(self):
        """First (JSON header) line of the chain file, decoded."""
        with open(self.filename, mode="rb") as file:
            # read header.
            file_header = file.readline().decode()
        return file_header

    def setup(self, draws, chain, overwrite=False):
        """
        Perform chain-specific setup. Creates file with header. If exist
        not overwritten again unless flag is set.

        Parameters
        ----------
        draws: int.
            Expected number of draws
        chain: int.
            Chain number
        overwrite: Bool (optional).
            True if the file needs to be overwritten, False (default) to
            append to an existing file.
        """
        logger.debug("SetupTrace: Chain_%i step_%i" % (chain, draws))
        self.chain = chain
        self.draws = draws
        self.filename = os.path.join(self.dir_path, "chain-{}.bin".format(chain))
        self.__data_structure = self.construct_data_structure()
        if os.path.exists(self.filename) and not overwrite:
            logger.info("Found existing trace, appending!")
        else:
            logger.debug('Setup new "bin" trace for chain %i' % chain)
            self.count = 0
            data_type = OrderedDict()
            with open(self.filename, "wb") as fh:
                for name, dtype in self.var_dtypes.items():
                    data_type[name] = "{}".format(dtype)
                header_data = {
                    self.flat_names_tag: self.flat_names,
                    self.var_shape_tag: self.var_shapes,
                    self.var_dtypes_tag: data_type,
                }
                header = (json.dumps(header_data) + "\n").encode()
                fh.write(header)

    def extract_variables_from_header(self, file_header):
        """Parse the JSON header into flat_names, var_shapes, var_dtypes, varnames."""
        header_data = json.loads(file_header, object_pairs_hook=OrderedDict)
        flat_names = header_data[self.flat_names_tag]
        var_shapes = OrderedDict()
        for name, shape in header_data[self.var_shape_tag].items():
            var_shapes[name] = tuple(shape)
        var_dtypes = header_data[self.var_dtypes_tag]
        varnames = list(flat_names.keys())
        return flat_names, var_shapes, var_dtypes, varnames

    def construct_data_structure(self):
        """
        Create a dtype to store the data based on varnames in a numpy array.

        If no flat names are known yet, they are read from the file header.

        Returns
        -------
        A numpy.dtype
        """
        if len(self.flat_names) == 0 and not self.corrupted_flag:
            (
                self.flat_names,
                self.var_shapes,
                self.var_dtypes,
                self.varnames,
            ) = self.extract_variables_from_header(self.file_header)

        formats = [
            "{shape}{dtype}".format(
                shape=self.var_shapes[name], dtype=self.var_dtypes[name]
            )
            for name in self.varnames
        ]

        # set data structure
        return num.dtype({"names": self.varnames, "formats": formats})

    def _write_data_to_file(self, lpoint=None):
        """
        Writes lpoint to file.
        If lpoint is None it will try to write from buffer.

        Write errors are logged, not raised (best effort).

        Parameters
        ----------
        lpoint: list of numpy arrays.
        """

        def lpoint2file(filehandle, varnames, data, values):
            for name, array in zip(varnames, values):
                data[name] = array
            data.tofile(filehandle)

        if lpoint is None and len(self.buffer) == 0:
            logger.debug("There is no data to write into file.")
            return

        # Write binary
        try:
            # single-record array reused as scratch space for every sample
            data = num.zeros(1, dtype=self.data_structure)
            with open(self.filename, mode="ab+") as fh:
                if lpoint is None:
                    write_buffer = thin_buffer(
                        self.buffer, self.buffer_thinning, ensure_last=True
                    )
                    for buffered_lpoint, _ in write_buffer:
                        lpoint2file(fh, self.varnames, data, buffered_lpoint)
                else:
                    lpoint2file(fh, self.varnames, data, lpoint)

        except EnvironmentError as e:
            logger.error("Error on write file: %s", e)

    def _load_df(self):
        """
        Read the records after the header line into `self._df` once.

        A header that cannot be decoded, or a file that ends before the
        header, sets `corrupted_flag`.
        """
        # explicit None check: a structured dtype without fields is falsy
        if self.__data_structure is None:
            try:
                self.__data_structure = self.construct_data_structure()
            except json.decoder.JSONDecodeError:
                logger.warning(
                    "File header of %s is corrupted!" " Resampling!" % self.filename
                )
                self.corrupted_flag = True

        if self._df is None and not self.corrupted_flag:
            try:
                with open(self.filename, mode="rb") as file:
                    # skip header.
                    next(file)
                    # read data
                    self._df = num.fromfile(file, dtype=self.data_structure)
            except (EOFError, StopIteration) as e:
                logger.error("Could not read trace %s: %r" % (self.filename, e))
                self.corrupted_flag = True

    def get_values(self, varname, burn=0, thin=1):
        """
        Get values of `varname` with shape ``(n_samples,) + var_shape``,
        sliced by ``[burn::thin]``.

        Raises
        ------
        ValueError
            if `varname` is not a field of the stored records
        """
        self._load_df()
        try:
            data = self._df[varname]
            shape = (self._df.shape[0],) + self.var_shapes[varname]
            vals = data.ravel().reshape(shape)
            return vals[burn::thin]
        except ValueError:
            raise ValueError(
                'Did not find varname "%s" in sampling ' "results! Fixed?" % varname
            )

    def point(self, idx):
        """
        Get point of current chain with variables names as keys.

        Parameters
        ----------
        idx : int
            Index of the nth step of the chain

        Returns
        -------
        dictionary of point values
        """
        idx = int(idx)
        self._load_df()
        pt = {}
        for varname in self.varnames:
            data = self._df[varname][idx]
            pt[varname] = data.reshape(self.var_shapes[varname])
        return pt
# Maps a backend name -- also the chain file extension, see
# "chain-*.<backend>" in load_multitrace -- to the trace class handling it.
backend_catalog = {"csv": TextChain, "bin": NumpyChain}
class TransDTextChain(object):
    """
    Result Trace object for trans-d problems.
    Manages several TextChains one for each dimension.

    NOTE(review): the index chain is set up as chain 0 in the same
    directory as the per-dimension chains; if dimensions.lower == 0 both
    use "chain-0.csv" -- confirm intended layout.
    """

    def __init__(
        self, name, model=None, value_vars=None, buffer_size=5000, progressbar=False
    ):
        # value_vars is accepted for interface compatibility but not used:
        # the per-dimension chains are built from `model` directly.
        self._straces = {}
        self.buffer_size = buffer_size
        self.progressbar = progressbar

        # dims_idx is used to index model.unobserved_RVs below, so look the
        # "dimensions" variable up among exactly those names.
        rv_names = [var.name for var in model.unobserved_RVs]
        transd, dims_idx = istransd(rv_names)
        if transd:
            self.dims_idx = dims_idx
        else:
            raise ValueError("Model is not trans-d but TransD Chain initialized!")

        dimensions = model.unobserved_RVs[self.dims_idx]

        for k in range(dimensions.lower, dimensions.upper + 1):
            self._straces[k] = TextChain(
                dir_path=name,
                model=model,
                buffer_size=buffer_size,
                progressbar=progressbar,
                k=k,
            )

        # init indexing chain
        self._index = TextChain(
            dir_path=name,
            value_vars=[],
            buffer_size=self.buffer_size,
            progressbar=self.progressbar,
        )
        self._index.flat_names = {"draw__0": (1,), "k__0": (1,), "k_idx__0": (1,)}

    def setup(self, draws, chain):
        # one-element array reused by `write` to hold the current draw index
        self.draws = num.zeros(1, dtype="int32")
        for k, trace in self._straces.items():
            trace.setup(draws=draws, chain=k)

        self._index.setup(draws, chain=0)

    def write(self, lpoint, draw):
        self.draws[0] = draw
        ipoint = [self.draws, lpoint[self.dims_idx]]

        self._index.write(ipoint, draw)
        # lpoint entries are numpy arrays (unhashable); sub-chains are keyed
        # by integer dimension
        self._straces[int(lpoint[self.dims_idx])].write(lpoint, draw)

    def __len__(self):
        # one index row is written per sample
        return len(self._index)

    def record_buffer(self):
        for trace in self._straces.values():
            trace.record_buffer()

        self._index.record_buffer()

    def point(self, idx):
        """
        Get point of current chain with variables names as keys.

        Parameters
        ----------
        idx : int
            Index of the nth step of the chain

        Returns
        -------
        dict : of point values
        """
        ipoint = self._index.point(idx)
        return self._straces[int(ipoint["k"])].point(ipoint["k_idx"])

    def get_values(self, varname):
        raise NotImplementedError()
class SampleStage(object):
    """
    Bookkeeping for the per-stage result directories of a sampling run.

    Parameters
    ----------
    base_dir : str
        directory holding ``stage_<n>`` sub-directories (created if missing)
    backend : str
        trace backend name, a key of `backend_catalog`
    """

    def __init__(self, base_dir, backend="csv"):
        self.base_dir = base_dir
        self.project_dir = os.path.dirname(base_dir)
        self.mode = os.path.basename(base_dir)
        self.backend = backend
        util.ensuredir(self.base_dir)

    def stage_path(self, stage):
        return os.path.join(self.base_dir, "stage_{}".format(stage))

    def trans_stage_path(self, stage):
        return os.path.join(self.base_dir, "trans_stage_{}".format(stage))

    def stage_number(self, stage_path):
        """
        Inverse function of SampleStage.stage_path
        """
        suffix = os.path.basename(stage_path).split("_")[-1]
        return int(suffix)

    def highest_sampled_stage(self):
        """
        Return stage number of stage that has been sampled before the final
        stage.

        Returns
        -------
        stage number : int
        """
        stage_dirs = glob(self.stage_path("*"))
        return max(self.stage_number(stage_dir) for stage_dir in stage_dirs)

    def get_stage_indexes(self, load_stage=None):
        """
        Return indexes to all sampled stages.

        Parameters
        ----------
        load_stage : int, optional
            if specified only return a list with this stage_index

        Returns
        -------
        list of int, stage_index that have been sampled
        """
        if load_stage is not None:
            if not isinstance(load_stage, int):
                raise ValueError('Requested stage_number has to be of type "int"')
            return [load_stage]

        highest = self.highest_sampled_stage()
        # with a final stage (-1) present, every stage up to `highest` is
        # complete; otherwise the highest one is excluded
        if os.path.exists(self.smc_path(-1)):
            return list(range(-1, highest + 1))
        return list(range(highest))

    def smc_path(self, stage_number):
        """
        Consistent naming for smc params.
        """
        return os.path.join(self.stage_path(stage_number), sample_p_outname)

    def load_sampler_params(self, stage_number):
        """
        Load saved parameters from last sampled stage.

        Parameters
        ----------
        stage number : int
            of stage number or -1 for last stage
        """
        if stage_number == -1:
            if os.path.exists(self.smc_path(stage_number)):
                prev = stage_number
            else:
                prev = self.highest_sampled_stage()
        elif stage_number == -2:
            prev = stage_number + 1
        else:
            prev = stage_number - 1

        logger.info("Loading parameters from completed stage {}".format(prev))
        sampler_state, updates = load_objects(self.smc_path(prev))
        sampler_state["stage"] = stage_number
        return sampler_state, updates

    def dump_smc_params(self, stage_number, outlist):
        """
        Save smc params to file.
        """
        dump_objects(self.smc_path(stage_number), outlist)

    def clean_directory(self, stage, chains, rm_flag):
        """
        Optionally remove directory for the stage.

        Returns `chains` unchanged if `rm_flag` is False and the stage
        directory exists; otherwise returns None.
        """
        stage_path = self.stage_path(stage)
        exists = os.path.exists(stage_path)

        if not rm_flag:
            return chains if exists else None

        if exists:
            logger.info("Removing previous sampling results ... %s" % stage_path)
            shutil.rmtree(stage_path)
        return None

    def load_multitrace(self, stage, chains=None, varnames=None):
        """
        Load TextChain database.

        Parameters
        ----------
        stage : int
            number of stage that should be loaded
        chains : list, optional
            of result chains to load, -1 is the summarized trace
        varnames : list
            of varnames in the model

        Returns
        -------
        A :class:`pymc.backend.base.MultiTrace` instance
        """
        return load_multitrace(
            dirname=self.stage_path(stage),
            chains=chains,
            varnames=varnames,
            backend=self.backend,
        )

    def recover_existing_results(
        self, stage, draws, step, buffer_thinning=1, varnames=None, update=None
    ):
        """
        Restore sampler state for `stage` from results on disk.

        For stage > 0 the end points of the previous stage are loaded into
        `step`. Returns the chain indexes to be resampled (from
        `check_multitrace`) if the stage already has chains, else None.
        """
        if stage > 0:
            prev = stage - 1
            if update is None:
                prev_stage_path = self.stage_path(prev)
            else:
                prev_stage_path = self.trans_stage_path(prev)

            logger.info(
                "Loading end points of last completed stage: " "%s" % prev_stage_path
            )
            prev_trace = load_multitrace(
                dirname=prev_stage_path, varnames=varnames, backend=self.backend
            )
            end_points = step.select_end_points(prev_trace)
            step.population, step.array_population, step.likelihoods = end_points

        stage_path = self.stage_path(stage)
        if not os.path.exists(stage_path):
            logger.info("Found no sampling results under %s " % stage_path)
        else:
            # load incomplete stage results
            logger.info("Reloading existing results ...")
            mtrace = self.load_multitrace(stage, varnames=varnames)
            if len(mtrace.chains):
                # continue sampling if traces exist
                logger.info("Checking for corrupted files ...")
                return check_multitrace(
                    mtrace,
                    draws=draws,
                    n_chains=step.n_chains,
                    buffer_thinning=buffer_thinning,
                )

        logger.info("Init new trace!")
        return None


def istransd(varnames):
    """
    Check whether a "dimensions" variable is among `varnames`.

    Returns
    -------
    tuple of (bool, index of "dimensions" in `varnames` or None)
    """
    dims = "dimensions"
    if dims not in varnames:
        logger.debug('Did not find "%s" random variable in model!' % dims)
        return False, None
    return True, varnames.index(dims)
def load_multitrace(dirname, varnames=None, chains=None, backend="csv"):
    """
    Load TextChain database.

    Parameters
    ----------
    dirname : str
        Name of directory with files (one per chain)
    varnames : list, optional
        of strings with variable names; None is treated as an empty list
    chains : list optional
        chain indexes to load; if None, all ``chain-*.<backend>`` files are
        loaded except the summarized trace (chain -1)
    backend : str
        key of `backend_catalog` and chain file extension

    Returns
    -------
    A :class:`pymc.backend.base.MultiTrace` instance

    Raises
    ------
    IOError
        if a requested chain file does not exist
    NotImplementedError
        for trans-d traces
    """
    # None instead of a shared mutable [] default; callers such as
    # SampleStage pass varnames=None explicitly.
    if varnames is None:
        varnames = []

    if istransd(varnames)[0]:
        logger.info("Loading trans-d trace from %s" % dirname)
        raise NotImplementedError("Loading trans-d trace is not implemented!")

    logger.info("Loading multitrace from %s" % dirname)
    if chains is None:
        files = glob(os.path.join(dirname, "chain-*.%s" % backend))
        chains = [
            int(os.path.splitext(os.path.basename(f))[0].replace("chain-", ""))
            for f in files
        ]

        # skip the summarized trace, it is not a sampling chain
        final_chain = -1
        if final_chain in chains:
            idx = chains.index(final_chain)
            files.pop(idx)
            chains.pop(idx)
    else:
        files = [
            os.path.join(dirname, "chain-%i.%s" % (chain, backend))
            for chain in chains
        ]
        for f in files:
            if not os.path.exists(f):
                raise IOError(
                    "File %s does not exist! Please run:"
                    ' "beat summarize <project_dir>"!' % f
                )

    straces = []
    for chain, f in zip(chains, files):
        strace = backend_catalog[backend](dirname)
        strace.chain = chain
        strace.filename = f
        straces.append(strace)
    return base.MultiTrace(straces)
def check_multitrace(mtrace, draws, n_chains, buffer_thinning=1):
    """
    Check multitrace for incomplete sampling and return indexes from chains
    that need to be resampled.

    Parameters
    ----------
    mtrace : :class:`pymc.backend.base.MultiTrace`
        Multitrace object containing the sampling traces
    draws : int
        Number of steps (i.e. chain length for each Markov Chain)
    n_chains : int
        Number of Markov Chains
    buffer_thinning : int
        thinning applied when writing; the expected stored length is
        ``ceil(draws / buffer_thinning)``

    Returns
    -------
    list of chain indexes that need to be resampled: corrupted or incomplete
    chains first, then chains with no trace at all
    """
    not_sampled_idx = []

    # apply buffer thinning
    draws = int(num.ceil(draws / buffer_thinning))

    for chain in range(n_chains):
        if chain in mtrace.chains:
            chain_len = len(mtrace._straces[chain])
            if chain_len != draws:
                logger.warning(
                    "Trace number %i incomplete: (%i / %i)" % (chain, chain_len, draws)
                )
                mtrace._straces[chain].corrupted_flag = True
        else:
            not_sampled_idx.append(chain)

    # report chain ids, not positions within mtrace.chains -- they differ
    # whenever some chains are missing
    corrupted_idx = [
        chain for chain in mtrace.chains if mtrace._straces[chain].corrupted_flag
    ]
    return corrupted_idx + not_sampled_idx
def get_highest_sampled_stage(homedir, return_final=False):
    """
    Return stage number of stage that has been sampled before the final
    stage.

    Parameters
    ----------
    homedir : str
        Directory to the sampled stage results
    return_final : bool
        if True, return the non-numeric suffix (e.g. "final") of the first
        such ``stage_*`` entry encountered

    Returns
    -------
    stage number : int
    """
    numbers = []
    for stage_dir in glob(os.path.join(homedir, "stage_*")):
        suffix = os.path.splitext(stage_dir)[0].rsplit("_", 1)[1]
        try:
            numbers.append(int(suffix))
        except ValueError:
            logger.debug("string - That's the final stage!")
            if return_final:
                return suffix

    return max(numbers)
def load_sampler_params(project_dir, stage_number, mode):
    """
    Load saved parameters from given smc stage.

    Parameters
    ----------
    project_dir : str
        absolute path to directory of BEAT project
    stage number : string
        of stage number or 'final' for last stage
    mode : str
        problem mode that has been solved ('geometry', 'static', 'kinematic')

    Returns
    -------
    whatever `load_objects` returns for
    ``<project_dir>/<mode>/stage_<stage_number>/<sample_p_outname>``
    """
    stage_dir = os.path.join(project_dir, mode, "stage_%s" % stage_number)
    params_file = os.path.join(stage_dir, sample_p_outname)
    return load_objects(params_file)
def concatenate_traces(mtraces):
    """
    Concatenate a List of MultiTraces with same chain indexes.

    The input traces are deep-copied and left untouched. For every chain of
    the first trace, the loaded data (`_df`: a pandas DataFrame for text
    chains, a structured numpy array for binary chains) of all traces is
    concatenated in list order.

    Parameters
    ----------
    mtraces : list of :class:`pymc.backend.base.MultiTrace`
        all traces must contain the chains of the first one

    Returns
    -------
    :class:`pymc.backend.base.MultiTrace`
        copy of the first trace holding the concatenated data
    """
    base_traces = copy.deepcopy(mtraces)
    cat_trace = base_traces.pop(0)

    # collect per chain id (ids need not be 0..n-1) and concatenate once,
    # avoiding repeated (quadratic) appends
    pieces = {chain: [] for chain in cat_trace.chains}
    for mtrace in [cat_trace] + base_traces:
        for chain in cat_trace.chains:
            strace = mtrace._straces[chain]
            strace._load_df()
            pieces[chain].append(strace._df)

    for chain, chain_pieces in pieces.items():
        if isinstance(chain_pieces[0], pd.DataFrame):
            # DataFrame.append was removed in pandas 2.0
            merged = pd.concat(chain_pieces)
        else:
            merged = num.concatenate(chain_pieces)
        cat_trace._straces[chain]._df = merged

    return cat_trace
def extract_variables_from_df(dataframe):
    """
    Extract random variables and their shapes from the pymc-pandas data-frame

    Variable names are the column-name parts before ``"__"``; they are
    returned in order of first appearance in the columns (previously the
    order came from a `set` and varied between runs).

    Parameters
    ----------
    dataframe : :class:`pandas.DataFrame`

    Returns
    -------
    flat_names : dict
        with variable-names and respective flat-name indexes to data-frame
    var_shapes : dict
        with variable names and shapes
    """
    # single pass: group column names by variable, preserving column order
    flat_names = OrderedDict()
    for column in dataframe.columns:
        flat_name = str(column)
        varname = flat_name.split("__")[0]
        flat_names.setdefault(varname, []).append(flat_name)

    var_shapes = OrderedDict(
        (varname, _create_shape(indexes)) for varname, indexes in flat_names.items()
    )
    return flat_names, var_shapes
def extract_bounds_from_summary(summary, varname, shape, roundto=None, alpha=0.01):
    """
    Extract lower and upper bound of random variable.

    Bounds are read from the ``hdi_<100*alpha/2>%`` and
    ``hdi_<100*(1-alpha/2)>%`` columns of `summary`. If `roundto` is given,
    the lower bound is floored and the upper bound ceiled to `roundto`
    decimal places.

    Returns
    -------
    list of num.Ndarray
        [lower, upper], each of the given `shape`
    """

    def _identity(value):
        return value

    indexes = _create_flat_names_summary(varname, shape)
    lower_quant = "hdi_{0:g}%".format(100 * alpha / 2)
    upper_quant = "hdi_{0:g}%".format(100 * (1 - alpha / 2))

    bounds = []
    for quant in [lower_quant, upper_quant]:
        if roundto is None:
            adjust = 1.0
            operation = _identity
        else:
            adjust = 10.0**roundto
            # round outwards: down for the lower bound, up for the upper one
            operation = num.floor if quant == lower_quant else num.ceil

        values = num.empty(shape, "float64")
        for i, idx in enumerate(indexes):
            values[i] = operation(summary[quant][idx] * adjust) / adjust
        bounds.append(values)

    return bounds
def multitrace_to_inference_data(mtrace):
    """
    Convert a MultiTrace into an ArviZ InferenceData posterior.

    Each variable's values are transposed; variables with more than one
    element get a leading singleton axis via ``atleast_3d(...).T``
    (presumably giving (chain, draw, element) -- TODO confirm against
    `mtrace.get_values` output). A warning is logged for variables
    containing NaNs.
    """
    posterior = {}
    for varname in mtrace.varnames:
        vals = num.atleast_2d(mtrace.get_values(varname).T)
        if num.isnan(vals).any():
            logger.warning("Variable '%s' contains NaN values." % varname)

        size, _ = vals.shape
        posterior[varname] = num.atleast_3d(vals).T if size > 1 else vals

    return convert_to_inference_data(posterior)