Source code for pymc.smc.sampling

#   Copyright 2024 - present The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

import logging
import time

from typing import Any

import numpy as np

from arviz import InferenceData
from rich.theme import Theme

import pymc

from pymc.backends.arviz import dict_to_dataset, to_inference_data
from pymc.backends.base import MultiTrace
from pymc.model import Model, modelcontext
from pymc.progress_bar import SMCProgressBarManager, default_progress_theme
from pymc.sampling.mcmc import setup_cores_blas_cores
from pymc.sampling.parallel import _cpu_count, _initialize_multiprocessing_context
from pymc.smc.kernels import IMH
from pymc.smc.parallel import ParallelSMCSampler
from pymc.stats.convergence import log_warnings, run_convergence_checks
from pymc.util import (
    RandomState,
    _get_seeds_per_chain,
)

logger = logging.getLogger(__name__)


def sample_smc(
    draws=2000,
    kernel=IMH,
    *,
    start=None,
    model=None,
    random_seed: RandomState = None,
    chains=None,
    cores=None,
    blas_cores: int | str | None = None,
    compute_convergence_checks=True,
    return_inferencedata=True,
    idata_kwargs=None,
    progressbar=True,
    progressbar_theme: Theme | None = default_progress_theme,
    compile_kwargs: dict | None = None,
    mp_ctx=None,
    **kernel_kwargs,
) -> InferenceData | MultiTrace:
    r"""
    Sequential Monte Carlo based sampling.

    Parameters
    ----------
    draws : int, default 2000
        The number of samples to draw from the posterior (i.e. last stage). And also the number of
        independent chains. Defaults to 2000.
    kernel : SMC_kernel, optional
        SMC kernel used. Defaults to :class:`pymc.smc.kernels.IMH` (Independent Metropolis
        Hastings)
    start : dict or array of dict, optional
        Starting point in parameter space. Either a single dict, which is reused for every chain,
        or a list of dict with length `chains`. When None (default) the starting point is sampled
        from the prior distribution.
    model : Model (optional if in ``with`` context).
    random_seed : int, array_like of int, RandomState or numpy_Generator, optional
        Random seed(s) used by the sampling steps. If a list, tuple or array of ints
        is passed, each entry will be used to seed each chain. A ValueError will be
        raised if the length does not match the number of chains.
    chains : int, optional
        The number of chains to sample. Running independent chains is important for some
        convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
        is larger.
    cores : int, default None
        The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
        system. When ``chains`` is given, at most ``chains`` cores are used.
    blas_cores: int or "auto" or None, default None
        The total number of threads blas and openmp functions should use during sampling.
        Setting it to "auto" will ensure that the total number of active blas threads is the
        same as the `cores` argument. If set to an integer, the sampler will try to use that total
        number of blas threads. If `blas_cores` is not divisible by `cores`, it might get rounded
        down. If set to None, this will keep the default behavior of whatever blas implementation
        is used at runtime.
    compute_convergence_checks : bool, default True
        Whether to compute sampler statistics like ``R hat`` and ``effective_n``.
        Defaults to ``True``.
    return_inferencedata : bool, default True
        Whether to return the trace as an InferenceData (True) object or a MultiTrace (False).
        Defaults to ``True``.
    idata_kwargs : dict, optional
        Keyword arguments for :func:`pymc.to_inference_data`.
    progressbar : bool, optional, default True
        Whether or not to display a progress bar in the command line.
    progressbar_theme : Theme, optional
        Custom theme for progress bar. Defaults to the standard PyMC progress bar theme.
    compile_kwargs: dict, optional
        Keyword arguments to pass to pytensor.function
    mp_ctx : multiprocessing.context or str, optional
        Multiprocessing context for parallel chains. Can be a context object or a string
        ("fork", "spawn", or "forkserver"). If None, defaults to "fork" on macOS ARM
        and "forkserver" on other macOS systems, and the system default elsewhere.
    **kernel_kwargs : dict, optional
        Keyword arguments passed to the SMC_kernel. The default IMH kernel takes the following
        keywords:

        threshold : float, default 0.5
            Determines the change of beta from stage to stage, i.e. indirectly the number of
            stages, the higher the value of `threshold` the higher the number of stages. Defaults
            to 0.5. It should be between 0 and 1.
        correlation_threshold : float, default 0.01
            The lower the value the higher the number of MCMC steps computed automatically.
            Defaults to 0.01. It should be between 0 and 1.

        Additional keyword arguments for other kernels should be checked in the respective
        docstrings.

    Returns
    -------
    InferenceData or MultiTrace
        An ``InferenceData`` when ``return_inferencedata`` is True, otherwise the ``MultiTrace``
        holding one trace per chain.

    Raises
    ------
    ValueError
        If ``start`` is a sequence whose length differs from ``chains``.
    RuntimeError
        If a chain sampled in parallel did not produce any result.

    Notes
    -----
    SMC works by moving through successive stages. At each stage the inverse temperature
    :math:`\beta` is increased a little bit (starting from 0 up to 1). When :math:`\beta` = 0
    we have the prior distribution and when :math:`\beta = 1` we have the posterior distribution.
    So in more general terms, we are always computing samples from a tempered posterior that we can
    write as:

    .. math::

        p(\theta \mid y)_{\beta} = p(y \mid \theta)^{\beta} p(\theta)

    A summary of the algorithm is:

     1. Initialize :math:`\beta` at zero and stage at zero.
     2. Generate N samples :math:`S_{\beta}` from the prior (because when :math:`\beta = 0` the
        tempered posterior is the prior).
     3. Increase :math:`\beta` in order to make the effective sample size equal some predefined
        value (we use :math:`Nt`, where :math:`t` is 0.5 by default).
     4. Compute a set of N importance weights W. The weights are computed as the ratio of the
        likelihoods of a sample at stage i+1 and stage i.
     5. Obtain :math:`S_{w}` by re-sampling according to W.
     6. Use W to compute the mean and covariance for the proposal distribution, a MvNormal.
     7. Run N independent MCMC chains, starting each one from a different sample
        in :math:`S_{w}`. For the IMH kernel, the mean of the proposal distribution is the
        mean of the previous posterior stage and not the current point in parameter space.
     8. The N chains are run until the autocorrelation with the samples from the previous stage
        stops decreasing given a certain threshold.
     9. Repeat from step 3 until :math:`\beta \ge 1`.
     10. The final result is a collection of N samples from the posterior.


    References
    ----------
    .. [Minson2013] Minson, S. E., Simons, M., and Beck, J. L. (2013).
        "Bayesian inversion for finite fault earthquake source models I- Theory and algorithm."
        Geophysical Journal International, 2013, 194(3), pp.1701-1726.
        `link <https://gji.oxfordjournals.org/content/194/3/1701.full>`__

    .. [Ching2007] Ching, J., and Chen, Y. (2007).
        "Transitional Markov Chain Monte Carlo Method for Bayesian Model Updating, Model Class
        Selection, and Model Averaging." J. Eng. Mech., 2007, 133(7), pp. 816-832. doi:10.1061/(ASCE)0733-9399(2007)133:7(816).
        `link <http://ascelibrary.org/doi/abs/10.1061/%28ASCE%290733-9399
        %282007%29133:7%28816%29>`__
    """
    # Resolve the chain / core counts; never use more cores than chains.
    if cores is None:
        cores = _cpu_count()
    if chains is None:
        chains = max(2, cores)
    else:
        cores = min(chains, cores)

    if compile_kwargs is None:
        compile_kwargs = {}
    kernel_kwargs["compile_kwargs"] = compile_kwargs

    random_seed = _get_seeds_per_chain(random_state=random_seed, chains=chains)

    model = modelcontext(model)

    logger.info("Initializing SMC sampler...")

    mp_ctx = _initialize_multiprocessing_context(mp_ctx)

    # `cores` is re-bound here: the helper returns the core count to actually use.
    joined_blas_limiter, cores, num_blas_cores_per_worker = setup_cores_blas_cores(
        blas_cores, chains, cores, mp_ctx
    )

    t1 = time.time()

    rngs = [np.random.default_rng(seed) for seed in random_seed]

    # A single kernel instance is built up front; per-chain start points and
    # generators are supplied later, which is why `start=None` is passed here.
    with model:
        smc_kernel = kernel(
            draws=draws,
            start=None,
            model=model,
            random_seed=rngs[0].integers(2**30),
            **kernel_kwargs,
        )

    # Prepare start points for each chain
    start_points: list[dict | None]
    if start is None:
        start_points = [None] * chains
    elif isinstance(start, dict):
        start_points = [start] * chains
    else:
        if len(start) != chains:
            raise ValueError(f"Number of start dicts must match number of chains ({chains})")
        start_points = list(start)

    parallel = cores > 1 and chains > 1

    traces = []
    sample_stats = []
    sample_settings = []

    if parallel:
        logger.info(
            f"Sampling {chains} chain{'s' if chains > 1 else ''} "
            f"in {cores} job{'s' if cores > 1 else ''}"
        )

        with joined_blas_limiter():
            with ParallelSMCSampler(
                kernel=smc_kernel,
                chains=chains,
                cores=cores,
                rngs=rngs,
                start_points=start_points,
                progressbar=progressbar,
                progressbar_theme=progressbar_theme,
                mp_ctx=mp_ctx,
                blas_cores=num_blas_cores_per_worker,
            ) as sampler:
                results = list(sampler)

        # Group the results by chain, preserving arrival order within a chain.
        chain_results: list[list] = [[] for _ in range(chains)]
        for result in results:
            chain_results[result.chain].append(result)

        for chain_idx, chain_samples in enumerate(chain_results):
            if not chain_samples:
                raise RuntimeError(
                    f"Chain {chain_idx} did not produce any results. "
                    "This indicates a failure in parallel sampling."
                )

            # Only the last result received for a chain is turned into a trace.
            final_result = chain_samples[-1]

            trace = _build_trace_from_kernel_state(
                final_result.tempered_posterior,
                final_result.var_info,
                final_result.variables,
                chain_idx,
                model,
            )

            traces.append(trace)
            sample_stats.append(final_result.sample_stats)
            sample_settings.append(final_result.sample_settings)
    else:
        logger.info(
            f"Sampling {chains} chain{'s' if chains > 1 else ''}{' sequentially' if chains > 1 else ''}"
        )
        with joined_blas_limiter():
            # Appends to `traces`, `sample_stats` and `sample_settings` in place.
            _sample_smc_sequentially(
                kernel=smc_kernel,
                chains=chains,
                rngs=rngs,
                start_points=start_points,
                model=model,
                progressbar=progressbar,
                progressbar_theme=progressbar_theme,
                traces=traces,
                sample_stats=sample_stats,
                sample_settings=sample_settings,
            )

    trace = MultiTrace(traces)

    _t_sampling = time.time() - t1
    _, idata = _save_sample_stats(
        sample_settings,
        sample_stats,
        chains,
        trace,
        return_inferencedata,
        _t_sampling,
        idata_kwargs,
        model,
    )

    if compute_convergence_checks:
        if idata is None:
            # Pass the resolved model explicitly so the conversion does not depend
            # on a model context being active when `model` was given as argument.
            idata = to_inference_data(trace, log_likelihood=False, model=model)
        warns = run_convergence_checks(idata, model)
        trace.report._add_warnings(warns)
        log_warnings(warns)

    if return_inferencedata:
        assert idata is not None
        return idata
    return trace
def _save_sample_stats(
    sample_settings,
    sample_stats,
    chains,
    trace: MultiTrace,
    return_inferencedata: bool,
    _t_sampling,
    idata_kwargs,
    model: Model,
) -> tuple[Any | None, InferenceData | None]:
    """Attach sampler statistics and settings to the sampling output.

    Parameters
    ----------
    sample_settings : list
        Per-chain sampler settings; only the first entry is used.
    sample_stats : list
        Per-chain mappings from stat name to the values recorded for that chain.
    chains : int
        Number of chains. When larger than one, each stat is gathered across
        all chains into a list with one entry per chain.
    trace : MultiTrace
        Trace holding the draws of every chain.
    return_inferencedata : bool
        If False, stats and settings are set as attributes on ``trace.report``.
        If True, they are stored in the ``sample_stats`` group of a new
        InferenceData.
    _t_sampling : float
        Sampling time, stored in the settings under the key ``"_t_sampling"``.
    idata_kwargs : dict or None
        Extra keyword arguments for ``to_inference_data``; they take precedence
        over the default ``model`` entry.
    model : Model
        Model passed to ``to_inference_data``.

    Returns
    -------
    tuple
        ``(sample_stats, idata)``. With ``return_inferencedata=False`` this is
        the ``sample_stats`` argument and ``None``; otherwise the dataset built
        from the stats and the InferenceData containing it.
    """
    # Build fresh dicts so the caller's per-chain records are not modified.
    settings = {**sample_settings[0], "_t_sampling": _t_sampling}

    if chains > 1:
        # Collect the stat values from each chain in a single list
        stats = {
            stat: [chain_sample_stats[stat] for chain_sample_stats in sample_stats]
            for stat in sample_stats[0]
        }
    else:
        stats = dict(sample_stats[0])

    if not return_inferencedata:
        # Stats first, settings second: a setting wins on a name clash.
        for stat, value in stats.items():
            setattr(trace.report, stat, value)
        for stat, value in settings.items():
            setattr(trace.report, stat, value)
        return sample_stats, None

    # Different chains might have more iteration steps, leading to a
    # non-square `sample_stats` dataset, we cast as `object` to avoid
    # numpy ragged array deprecation warning
    stats_arrays = {
        stat: np.array(value, dtype=object) if chains > 1 else np.array(value)
        for stat, value in stats.items()
    }

    stats_dataset = dict_to_dataset(
        stats_arrays,
        attrs=settings,
        library=pymc,
    )

    ikwargs: dict[str, Any] = {"model": model}
    if idata_kwargs is not None:
        ikwargs.update(idata_kwargs)
    idata = to_inference_data(trace, **ikwargs)
    idata = InferenceData(**idata, sample_stats=stats_dataset)  # type: ignore[arg-type]

    return stats_dataset, idata


def _build_trace_from_kernel_state(
    tempered_posterior: np.ndarray,
    var_info: dict,
    variables: list,
    chain: int,
    model: Model,
):
    """Build a trace from kernel state.

    This allows trace building to happen in the main process rather than workers.

    Parameters
    ----------
    tempered_posterior : ndarray
        The final particle positions
    var_info : dict
        Dictionary of variable info {var.name: (shape, size)}
    variables : list
        List of model variables
    chain : int
        Chain index
    model : Model
        PyMC model for trace setup

    Returns
    -------
    NDArray trace backend
    """
    from pymc.backends.ndarray import NDArray
    from pymc.vartypes import discrete_types

    length_pos = len(tempered_posterior)

    strace = NDArray(name=model.name, model=model)
    strace.setup(length_pos, chain)

    # The position of each variable inside a flat particle is the same for every
    # particle, so work out the slices (and which variables need rounding) once.
    layout = []
    offset = 0
    for var in variables:
        shape, size = var_info[var.name]
        discrete_dtype = var.dtype if var.dtype in discrete_types else None
        layout.append((var.name, slice(offset, offset + size), shape, discrete_dtype))
        offset += size

    for particle in tempered_posterior:
        point = {}
        for name, span, shape, discrete_dtype in layout:
            var_samples = particle[span]
            # Round discrete variable samples
            if discrete_dtype is not None:
                var_samples = np.round(var_samples).astype(discrete_dtype)
            point[name] = var_samples.reshape(shape)
        strace.record(point=point)

    return strace


def _sample_smc_sequentially(
    *,
    kernel,
    chains: int,
    rngs: list[np.random.Generator],
    start_points: list[dict | None],
    model: Model,
    progressbar: bool,
    progressbar_theme: Theme | None,
    traces: list,
    sample_stats: list,
    sample_settings: list,
):
    """Sample all SMC chains sequentially.

    The same kernel instance is re-initialized for every chain. Results are
    not returned; they are appended to the ``traces``, ``sample_stats`` and
    ``sample_settings`` lists supplied by the caller.

    Parameters
    ----------
    kernel: SMC_KERNEL instance
        An initialized SMC kernel (with compiled functions)
    chains: int
        Total number of chains to sample
    rngs: list of random Generators
        A list of random number generators, one for each chain
    start_points: list
        Starting points for each chain
    model: Model
        The PyMC model
    progressbar: bool
        Whether to show progress bar
    progressbar_theme: Theme
        Progress bar theme
    traces: list
        List to append trace results to
    sample_stats: list
        List to append sample_stats to
    sample_settings: list
        List to append sample_settings to
    """
    with SMCProgressBarManager(
        kernel=kernel,
        chains=chains,
        progressbar=progressbar,
        progressbar_theme=progressbar_theme,
    ) as progress_manager:
        for i in range(chains):
            kernel.initialize(start_points[i], rngs[i])

            stage = 0
            chain_sample_stats: dict[str, list] = {stat: [] for stat in kernel.stats_dtypes_shapes}

            # Each iteration is one SMC stage: raise beta, report progress,
            # then run the kernel step and record the stats it returns.
            while kernel.beta < 1:
                old_beta = kernel.beta
                kernel.update_beta_and_weights()

                progress_manager.update(
                    chain_idx=i, stage=stage, beta=kernel.beta, old_beta=old_beta, is_last=False
                )

                for stat, value in kernel.step().items():
                    chain_sample_stats[stat].append(value)

                stage += 1

            progress_manager.update(chain_idx=i, stage=stage, beta=kernel.beta, is_last=True)

            # Snapshot this chain's final state before the kernel is reused.
            trace = _build_trace_from_kernel_state(
                tempered_posterior=kernel.tempered_posterior,
                var_info=kernel.var_info,
                variables=kernel.variables,
                chain=i,
                model=model,
            )

            traces.append(trace)
            sample_stats.append(chain_sample_stats)
            sample_settings.append(kernel.sample_settings())