Source code for pymc.smc.kernels

#   Copyright 2024 - present The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
import abc
import warnings

from abc import ABC
from typing import TypeAlias

import numpy as np
import pytensor.tensor as pt

from pytensor import shared
from pytensor.graph.replace import clone_replace
from pytensor.link.jax import JAXLinker
from pytensor.tensor.random.type import RandomGeneratorType
from rich.progress import TextColumn
from rich.table import Column
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

from pymc.blocking import DictToArrayBijection
from pymc.initial_point import make_initial_point_expression
from pymc.model import modelcontext
from pymc.pytensorf import (
    compile,
    floatX,
    join_nonshared_inputs,
    make_shared_replacements,
)
from pymc.sampling.forward import draw
from pymc.step_methods.metropolis import MultivariateNormalProposal
from pymc.vartypes import discrete_types

# Return type of ``SMC_KERNEL.sample_stats``: per-stage statistics keyed by stat name,
# each a scalar int or float.
SMCStats: TypeAlias = dict[str, int | float]
# Return type of ``SMC_KERNEL.sample_settings``: sampler settings keyed by setting name,
# each a scalar int or float.
SMCSettings: TypeAlias = dict[str, int | float]


class SMC_KERNEL(ABC):
    """Base class for the Sequential Monte Carlo kernels.

    To create a new SMC kernel you should subclass from this.

    During construction (``__init__``), the ``self.var_info`` dictionary is populated with
    information about model variables shape and size as {var.name : (var.shape, var.size)},
    and the functions ``self.prior_logp_func`` and ``self.likelihood_logp_func`` are
    compiled. These expect a 1D numpy array with the summed sizes of each raveled model
    variable (in the order specified in :meth:`pymc.Model.initial_point`).

    Before sampling, the following methods are called once in order:

        initialize_population
            Choose initial population of SMC particles. Should return a dictionary
            with {var.name : numpy array of size (draws, var.size)}. Defaults
            to sampling from the prior distribution, except for parameters which have custom
            `initval`, in which case that value is used for all SMC particles.
            This method is only called if `start` is not specified.

        _initialize_kernel : default
            Creates the initial population of particles in the variable
            `self.tempered_posterior`, computes the log prior and log likelihood for the
            initial particles, and saves them in `self.prior_logp` and
            `self.likelihood_logp`.

            This method should not be modified.

        setup_kernel : optional
            May include any logic that should be performed before sampling starts.

    During each sampling stage the following methods are called in order:

        update_beta_and_weights : default
            The inverse temperature `self.beta` is updated based on the
            `self.likelihood_logp` and `threshold` parameter.
            The importance `self.weights` of each particle are computed from the old and
            newly selected inverse temperature.
            The iteration number stored in `self.iteration` is updated by this method.
            Finally the model `log_marginal_likelihood` of the tempered posterior
            is updated from these weights.

        resample : default
            The particles in `self.tempered_posterior` are sampled with replacement based
            on `self.weights`, and the used resampling indexes are saved in
            `self.resampling_indexes`.
            The arrays `self.prior_logp` and `self.likelihood_logp` are rearranged
            according to the order of the resampled particles.
            `self.tempered_posterior_logp` is computed from these and the current
            `self.beta`.

        tune : optional
            May include logic that should be performed before every mutation step.

        mutate : REQUIRED
            Mutate particles in `self.tempered_posterior`.
            This method is further responsible to update the `self.prior_logp`,
            `self.likelihood_logp` and `self.tempered_posterior_logp`, corresponding
            to each mutated particle.

        sample_stats : REQUIRED
            Returns important sampling_stats at the end of each stage in a dictionary
            format. This will be saved in the final InferenceData object under
            `sample_stats`.

    Finally, at the end of sampling the following methods are called:

        _posterior_to_trace : default
            Convert final population of particles to a posterior trace object.
            This method should not be modified.

        sample_settings : REQUIRED
            Returns important sample_settings at the end of sampling in a dictionary
            format. This will be saved in the final InferenceData object under
            `sample_stats`.
    """

    stats_dtypes_shapes: dict[str, tuple[type, list]] = {
        "log_marginal_likelihood": (float, []),
        "beta": (float, []),
    }

    def __init__(
        self,
        draws=2000,
        start=None,
        model=None,
        random_seed=None,
        threshold=0.5,
        compile_kwargs: dict | None = None,
    ):
        """
        Initialize the SMC_kernel class.

        Parameters
        ----------
        draws : int, default 2000
            The number of samples to draw from the posterior (i.e. last stage). Also the number of
            independent chains. Defaults to 2000.
        start : dict, or array of dict, default None
            Starting point in parameter space. It should be a list of dict with length `chains`.
            When None (default) the starting point is sampled from the prior distribution, except
            for parameters with a custom `initval`, in which case that value is used.
        model : Model (optional if in ``with`` context).
        random_seed : int, array_like of int, RandomState or Generator, optional
            Value used to initialize the random number generator.
        threshold : float, default 0.5
            Determines the change of beta from stage to stage, i.e. indirectly the number of
            stages, the higher the value of `threshold` the higher the number of stages.
            Defaults to 0.5. It should be between 0 and 1.
        compile_kwargs: dict, optional
            Keyword arguments passed to pytensor.function. The dict is copied, never mutated.

        Attributes
        ----------
        self.var_info : dict
            Dictionary that contains information about model variables shape and size.

        Raises
        ------
        ValueError
            If `threshold` is outside [0, 1].
        """
        self.draws = draws
        self.start = start
        if threshold < 0 or threshold > 1:
            raise ValueError(f"Threshold value {threshold} must be between 0 and 1")
        self.threshold = threshold
        model = modelcontext(model)
        self.rng = np.random.default_rng(seed=random_seed)

        self.variables = model.value_vars
        self.var_info: dict[str, tuple] = {}
        self.tempered_posterior: np.ndarray
        self.prior_logp: np.ndarray | None = None
        self.likelihood_logp: np.ndarray | None = None
        self.tempered_posterior_logp: np.ndarray | None = None
        self.log_marginal_likelihood: float = 0.0
        self.beta = 0.0
        self.iteration = 0
        self.resampling_indexes: np.ndarray | None = None
        self.weights = np.ones(self.draws) / self.draws

        initial_point = model.initial_point(random_seed=self.rng.integers(2**30))
        for v in self.variables:
            self.var_info[v.name] = (initial_point[v.name].shape, initial_point[v.name].size)

        # Named so it does not shadow the module-level pytensor `shared` used by `set_rng`.
        shared_replacements = make_shared_replacements(initial_point, self.variables, model)

        # Work on a copy so the caller's dict is not mutated across kernel instances.
        compile_kwargs = dict(compile_kwargs) if compile_kwargs is not None else {}
        # If a model has no observed variables, the likelihood_logp will have unused inputs,
        # which can be safely ignored.
        compile_kwargs["on_unused_input"] = "ignore"

        self.prior_logp_func = _logp_forw(
            initial_point, [model.varlogp], self.variables, shared_replacements, compile_kwargs
        )
        self.likelihood_logp_func = _logp_forw(
            initial_point, [model.datalogp], self.variables, shared_replacements, compile_kwargs
        )

        prior_expression = make_initial_point_expression(
            free_rvs=model.free_RVs,
            rvs_to_transforms=model.rvs_to_transforms,
            initval_strategies={
                **model.rvs_to_initial_values,
            },
            default_strategy="prior",
            return_transformed=True,
        )
        self._prior_expression = prior_expression
        self._prior_var_names = [model.rvs_to_values[rv].name for rv in model.free_RVs]

    def set_rng(self, rng: np.random.Generator):
        """
        Copy compiled functions, updating their random number generators.

        This is necessary because these functions were compiled once at initialization,
        then pickled and sent to worker processes. Each worker needs its own RNG state
        to ensure independent sampling, so we replace the shared RNGs in the compiled
        functions with new ones created from the provided `rng`.

        This method copies the functions, so it is expensive and should only be called once
        per worker!

        Raises
        ------
        NotImplementedError
            If a compiled function uses the JAX linker and holds shared RNGs.
        """

        def make_rng_swaps(fn, rng):
            shared_rngs = [
                var for var in fn.get_shared() if isinstance(var.type, RandomGeneratorType)
            ]
            n_shared_rngs = len(shared_rngs)
            if n_shared_rngs > 0 and isinstance(fn.maker.linker, JAXLinker):
                raise NotImplementedError(
                    f"JAX rngs cannot be replaced after compilation. {self}.set_rng will fail to "
                    f"properly update random seeds between chains, resulting in non-independent "
                    f"sampling."
                )
            return {
                old_shared_rng: shared(new_rng, borrow=True)
                for old_shared_rng, new_rng in zip(
                    shared_rngs, rng.spawn(n_shared_rngs), strict=True
                )
            }

        self.rng = rng
        self.prior_logp_func = self.prior_logp_func.copy(
            swap=make_rng_swaps(self.prior_logp_func, self.rng)
        )
        self.likelihood_logp_func = self.likelihood_logp_func.copy(
            swap=make_rng_swaps(self.likelihood_logp_func, self.rng)
        )

    def initialize_population(self) -> dict[str, np.ndarray]:
        """Create an initial population from the prior distribution."""
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", category=UserWarning, message="The effect of Potentials"
            )

            prior_values = draw(self._prior_expression, draws=self.draws, random_seed=self.rng)

            dict_prior = dict(zip(self._prior_var_names, prior_values))

            return dict_prior

    def _initialize_kernel(self):
        """Initialize particles and compute their prior and likelihood logp.

        This method should not be overwritten. If needed, use `setup_kernel`
        instead.
        """
        if self.start:
            init_rnd = self.start
        else:
            init_rnd = self.initialize_population()

        # Ravel each particle's per-variable values into one flat vector.
        population = []
        for i in range(self.draws):
            point = {v.name: init_rnd[v.name][i] for v in self.variables}
            population.append(DictToArrayBijection.map(point).data)
        self.tempered_posterior = np.array(floatX(population))

        # Evaluate prior and likelihood for initial particles
        priors = [self.prior_logp_func(sample) for sample in self.tempered_posterior]
        likelihoods = [self.likelihood_logp_func(sample) for sample in self.tempered_posterior]

        self.prior_logp = np.array(priors).squeeze()
        self.likelihood_logp = np.array(likelihoods).squeeze()

    def setup_kernel(self):
        """Perform setup logic once before sampling starts."""
        pass

    def update_beta_and_weights(self):
        """Calculate the next inverse temperature (beta).

        The importance weights are based on two successive tempered likelihoods (i.e. two
        successive values of beta), and the marginal likelihood estimate is updated from
        them. The effective sample size (ESS) of the importance weights selects the next
        beta. BDA 3rd ed. eq 10.4
        """
        self.iteration += 1

        low_beta = old_beta = self.beta
        up_beta = 2.0

        # Target ESS: a `threshold` fraction of the particle count.
        rN = int(len(self.likelihood_logp) * self.threshold)

        # Bisection on beta. It assumes ESS decreases as beta moves away from old_beta:
        # too few effective samples -> lower the upper bound, too many -> raise the lower one.
        while up_beta - low_beta > 1e-6:
            new_beta = (low_beta + up_beta) / 2.0
            log_weights_un = (new_beta - old_beta) * self.likelihood_logp
            log_weights = log_weights_un - logsumexp(log_weights_un)
            # ESS = 1 / sum(w**2), evaluated in log space.
            ESS = int(np.exp(-logsumexp(log_weights * 2)))
            if ESS == rN:
                break
            elif ESS < rN:
                up_beta = new_beta
            else:
                low_beta = new_beta
        if new_beta >= 1:
            new_beta = 1
            log_weights_un = (new_beta - old_beta) * self.likelihood_logp
            log_weights = log_weights_un - logsumexp(log_weights_un)

        self.beta = new_beta
        self.weights = np.exp(log_weights)
        # We normalize again to correct for small numerical errors that might build up
        self.weights /= self.weights.sum()
        self.log_marginal_likelihood += logsumexp(log_weights_un) - np.log(self.draws)

    def resample(self):
        """Resample particles based on importance weights."""
        self.resampling_indexes = systematic_resampling(self.weights, self.rng)

        self.tempered_posterior = self.tempered_posterior[self.resampling_indexes]
        self.prior_logp = self.prior_logp[self.resampling_indexes]
        self.likelihood_logp = self.likelihood_logp[self.resampling_indexes]

        self.tempered_posterior_logp = self.prior_logp + self.likelihood_logp * self.beta

    def tune(self):
        """Tuning logic performed before every mutation step."""
        pass

    @abc.abstractmethod
    def mutate(self):
        """Apply kernel-specific perturbation to the particles once per stage."""
        pass

    @abc.abstractmethod
    def sample_stats(self) -> SMCStats:
        """Stats to be saved at the end of each stage.

        These stats will be saved under `sample_stats` in the final InferenceData object.
        """
        pass

    def step(self) -> SMCStats:
        """Perform a single SMC stage: resample, tune, and mutate.

        `update_beta_and_weights` is not called here; it must be invoked beforehand.
        """
        self.resample()
        self.tune()
        self.mutate()
        return self.sample_stats()

    @abc.abstractmethod
    def sample_settings(self) -> SMCSettings:
        """SMC_kernel settings to be saved once at the end of sampling.

        These stats will be saved under `sample_stats` in the final InferenceData object.
        """
        pass

    def _reset_state(self):
        """Reset the sampling state for a new run."""
        self.tempered_posterior = np.empty(0)
        self.prior_logp = None
        self.likelihood_logp = None
        self.tempered_posterior_logp = None
        self.log_marginal_likelihood = 0.0
        self.beta = 0.0
        self.iteration = 0
        self.resampling_indexes = None
        self.weights = np.ones(self.draws) / self.draws

    def initialize(self, start: dict | None, rng: np.random.Generator) -> None:
        """Initialize the kernel for sampling.

        Parameters
        ----------
        start : dict or None
            Starting point in parameter space, or None to sample from prior.
        rng : np.random.Generator
            Random number generator for this chain.
        """
        self.start = start
        self.rng = rng
        self._reset_state()
        self.set_rng(rng)
        self._initialize_kernel()
        self.setup_kernel()

    @staticmethod
    def _progressbar_config(n_chains=1):
        """Configure progress bar columns for SMC sampling.

        Returns columns to display and initial stats values.
        """
        columns = [
            TextColumn("{task.fields[beta]:.4f}", table_column=Column("Beta", ratio=1)),
        ]

        stats = {
            "beta": [0.0] * n_chains,
        }

        return columns, stats

    @staticmethod
    def _make_progressbar_update_functions():
        """Create functions to update progress bar statistics."""

        def update_stats(stats):
            return {
                "beta": stats.get("beta", 0.0),
            }

        return (update_stats,)
class IMH(SMC_KERNEL):
    """Independent Metropolis-Hastings SMC_kernel."""

    stats_dtypes_shapes: dict[str, tuple[type, list]] = {
        "log_marginal_likelihood": (float, []),
        "beta": (float, []),
        "accept_rate": (float, []),
    }

    def __init__(self, *args, correlation_threshold=0.01, **kwargs):
        """
        Create the Independent Metropolis-Hastings SMC kernel object.

        Parameters
        ----------
        correlation_threshold : float, default 0.01
            The lower the value, the higher the number of IMH steps computed automatically.
            Defaults to 0.01. It should be between 0 and 1.
        **kwargs : dict, optional
            Keyword arguments passed to the SMC_kernel. Refer to SMC_kernel documentation for a
            list of all possible arguments.
        """
        super().__init__(*args, **kwargs)
        self.correlation_threshold = correlation_threshold

        self.proposal_dist = None
        self.acc_rate = None

    def tune(self):
        """Fit the independent MvNormal proposal to the current tempered posterior.

        Raises
        ------
        ValueError
            If the sample covariance contains NaN or inf values.
        """
        # Update MVNormal proposal based on the mean and covariance of the
        # tempered posterior.
        cov = np.cov(self.tempered_posterior, ddof=0, rowvar=0)
        cov = np.atleast_2d(cov)
        # Small diagonal jitter keeps the covariance positive definite.
        cov += 1e-6 * np.eye(cov.shape[0])
        if np.isnan(cov).any() or np.isinf(cov).any():
            raise ValueError('Sample covariances not valid! Likely "draws" is too small!')
        mean = np.average(self.tempered_posterior, axis=0)
        self.proposal_dist = multivariate_normal(mean, cov)

    def mutate(self):
        """Independent Metropolis-Hastings perturbation.

        IMH steps are repeated while more than 90% of the dimensions keep lowering their
        absolute correlation with the pre-mutation particles by more than
        ``correlation_threshold``. The overall acceptance rate is stored in ``self.acc_rate``.
        """
        self.n_steps = 1
        old_corr = 2
        corr = Pearson(self.tempered_posterior)
        ac_ = []

        # The proposal is independent from the current point, so the Metropolis-Hastings
        # acceptance needs the proposal logp of moving back to each current particle.
        # It is computed once here and, at the end of every iteration, patched with the
        # entries of the accepted transitions. This equals recomputing it every iteration.
        backward_logp = np.atleast_1d(self.proposal_dist.logpdf(self.tempered_posterior))

        while True:
            log_R = np.log(self.rng.random(self.draws))
            proposal = floatX(self.proposal_dist.rvs(size=self.draws, random_state=self.rng))
            proposal = proposal.reshape(len(proposal), -1)
            # Proposal logp of moving forward to each new point.
            forward_logp = np.atleast_1d(self.proposal_dist.logpdf(proposal))

            ll = np.array([self.likelihood_logp_func(prop) for prop in proposal])
            pl = np.array([self.prior_logp_func(prop) for prop in proposal])
            proposal_logp = pl + ll * self.beta
            accepted = log_R < (
                (proposal_logp + backward_logp) - (self.tempered_posterior_logp + forward_logp)
            )

            self.tempered_posterior[accepted] = proposal[accepted]
            self.tempered_posterior_logp[accepted] = proposal_logp[accepted]
            self.prior_logp[accepted] = pl[accepted]
            self.likelihood_logp[accepted] = ll[accepted]
            # Accepted particles now sit at the proposed points.
            backward_logp[accepted] = forward_logp[accepted]
            ac_.append(accepted)
            self.n_steps += 1

            pearson_r = corr.get(self.tempered_posterior)
            if np.mean((old_corr - pearson_r) > self.correlation_threshold) > 0.9:
                old_corr = pearson_r
            else:
                break

        self.acc_rate = np.mean(ac_)

    def sample_stats(self) -> SMCStats:
        """Return per-stage stats; the marginal likelihood is NaN until beta reaches 1."""
        return {
            "log_marginal_likelihood": self.log_marginal_likelihood if self.beta == 1 else np.nan,
            "beta": self.beta,
            "accept_rate": self.acc_rate,
        }

    def sample_settings(self) -> SMCSettings:
        """Return the kernel settings recorded once at the end of sampling."""
        return {
            "_n_draws": self.draws,
            "threshold": self.threshold,
            "_n_tune": self.n_steps,
            "correlation_threshold": self.correlation_threshold,
        }
class Pearson:
    """Column-wise absolute Pearson correlation against a fixed reference sample.

    The reference array is centred once at construction, so each call to :meth:`get`
    only has to centre the array being compared.
    """

    def __init__(self, a):
        n_rows = a.shape[0]
        centred = self._center(a, n_rows)
        self.l = n_rows
        self.am = centred
        self.aa = self._column_norm(centred)

    @staticmethod
    def _center(x, n_rows):
        """Return ``x`` with its column means (column sum / ``n_rows``) subtracted."""
        return x - np.sum(x, axis=0) / n_rows

    @staticmethod
    def _column_norm(x):
        """Return the Euclidean norm of every column of ``x``."""
        return np.sum(x**2, axis=0) ** 0.5

    def get(self, b):
        """Return ``|corr|`` between each column of the reference and the same column of ``b``."""
        b_centred = self._center(b, self.l)
        b_norm = self._column_norm(b_centred)
        cross = np.sum(self.am * b_centred, axis=0)
        return np.abs(cross / (self.aa * b_norm))
class MH(SMC_KERNEL):
    """Metropolis-Hastings SMC_kernel."""

    stats_dtypes_shapes: dict[str, tuple[type, list]] = {
        "log_marginal_likelihood": (float, []),
        "beta": (float, []),
        "mean_accept_rate": (float, []),
        "mean_proposal_scale": (float, []),
    }

    def __init__(self, *args, correlation_threshold=0.01, **kwargs):
        """
        Create a Metropolis-Hastings SMC kernel.

        Parameters
        ----------
        correlation_threshold : float, default 0.01
            The lower the value, the higher the number of MH steps computed automatically.
            Defaults to 0.01. It should be between 0 and 1.
        **kwargs : dict, optional
            Keyword arguments passed to the SMC_kernel. Refer to SMC_kernel documentation for a
            list of all possible arguments.
        """
        super().__init__(*args, **kwargs)
        self.correlation_threshold = correlation_threshold

        self.proposal_dist = None
        self.proposal_scales = None
        self.chain_acc_rate = None

    def setup_kernel(self):
        """Initialize one proposal scale per particle.

        Every particle starts with scale ``min(1, 2.38**2 / ndim)``, where ``ndim`` is the
        length of the raveled parameter vector. The proposal distribution itself is built
        in :meth:`tune` from the covariance of the tempered posterior, and its draws are
        multiplied by these per-particle scales in :meth:`mutate`.
        """
        ndim = self.tempered_posterior.shape[1]
        self.proposal_scales = np.full(self.draws, min(1, 2.38**2 / ndim))

    def resample(self):
        """Resample particles and keep per-particle scales and acceptance rates aligned."""
        super().resample()
        # `chain_acc_rate` is only set by `mutate`, so it exists from the second stage on.
        if self.iteration > 1:
            self.proposal_scales = self.proposal_scales[self.resampling_indexes]
            self.chain_acc_rate = self.chain_acc_rate[self.resampling_indexes]

    def tune(self):
        """Adapt the per-particle proposal scales and rebuild the proposal distribution.

        From the second stage on, each particle's scale is shifted in log space by the
        difference between its previous acceptance rate and 0.234, then averaged with the
        population mean scale. The proposal is then a MultivariateNormalProposal using the
        (jittered) covariance of the tempered posterior. The number of MH steps is decided
        in :meth:`mutate`, not here.

        Raises
        ------
        ValueError
            If the sample covariance contains NaN or inf values.
        """
        if self.iteration > 1:
            # Rescale based on distance to 0.234 acceptance rate
            chain_scales = np.exp(np.log(self.proposal_scales) + (self.chain_acc_rate - 0.234))
            # Interpolate between individual and population scales
            self.proposal_scales = 0.5 * (chain_scales + chain_scales.mean())

        # Update MVNormal proposal based on the covariance of the tempered posterior.
        cov = np.cov(self.tempered_posterior, ddof=0, rowvar=0)
        cov = np.atleast_2d(cov)
        # Small diagonal jitter keeps the covariance positive definite.
        cov += 1e-6 * np.eye(cov.shape[0])
        if np.isnan(cov).any() or np.isinf(cov).any():
            raise ValueError('Sample covariances not valid! Likely "draws" is too small!')
        self.proposal_dist = MultivariateNormalProposal(cov)

    def mutate(self):
        """Metropolis-Hastings perturbation.

        Random-walk MH steps (proposal draw times the particle's scale) are repeated while
        more than 90% of the dimensions keep lowering their absolute correlation with the
        pre-mutation particles by more than ``correlation_threshold``. Per-particle
        acceptance rates are stored in ``self.chain_acc_rate``.
        """
        self.n_steps = 1
        old_corr = 2
        corr = Pearson(self.tempered_posterior)
        ac_ = []
        while True:
            log_R = np.log(self.rng.random(self.draws))
            proposal = floatX(
                self.tempered_posterior
                + self.proposal_dist(num_draws=self.draws, rng=self.rng)
                * self.proposal_scales[:, None]
            )
            ll = np.array([self.likelihood_logp_func(prop) for prop in proposal])
            pl = np.array([self.prior_logp_func(prop) for prop in proposal])

            proposal_logp = pl + ll * self.beta
            # Symmetric random-walk proposal: the acceptance ratio needs no proposal terms.
            accepted = log_R < (proposal_logp - self.tempered_posterior_logp)

            self.tempered_posterior[accepted] = proposal[accepted]
            self.prior_logp[accepted] = pl[accepted]
            self.likelihood_logp[accepted] = ll[accepted]
            self.tempered_posterior_logp[accepted] = proposal_logp[accepted]
            ac_.append(accepted)
            self.n_steps += 1

            pearson_r = corr.get(self.tempered_posterior)
            if np.mean((old_corr - pearson_r) > self.correlation_threshold) > 0.9:
                old_corr = pearson_r
            else:
                break

        self.chain_acc_rate = np.mean(ac_, axis=0)

    def sample_stats(self) -> SMCStats:
        """Return per-stage stats; the marginal likelihood is NaN until beta reaches 1."""
        return {
            "log_marginal_likelihood": self.log_marginal_likelihood if self.beta == 1 else np.nan,
            "beta": self.beta,
            "mean_accept_rate": self.chain_acc_rate.mean(),
            "mean_proposal_scale": self.proposal_scales.mean(),
        }

    def sample_settings(self) -> SMCSettings:
        """Return the kernel settings recorded once at the end of sampling."""
        return {
            "_n_draws": self.draws,
            "threshold": self.threshold,
            "_n_tune": self.n_steps,
            "correlation_threshold": self.correlation_threshold,
        }
def systematic_resampling(weights, rng):
    """
    Systematic resampling.

    Parameters
    ----------
    weights : array_like
        The weights should be probabilities and the total sum should be 1.
    rng : np.random.Generator
        Source of the single uniform offset shared by all resampling positions.

    Returns
    -------
    new_indices: array
        A vector of indices in the interval 0, ..., len(normalized_weights) - 1
    """
    weights = np.asarray(weights)
    n_weights = len(weights)
    if n_weights == 0:
        return np.empty(0, dtype=int)

    # n evenly spaced positions in [0, 1) sharing one uniform offset.
    positions = (rng.random(1) + np.arange(n_weights)) / n_weights
    cumulative = np.cumsum(weights)
    # For each position, the first particle whose cumulative weight reaches it.
    new_indices = np.searchsorted(cumulative, positions, side="left")
    # Round-off can leave the total weight a hair below the last position; without
    # this clamp that position would map one past the final particle.
    return np.minimum(new_indices, n_weights - 1).astype(int, copy=False)


def _logp_forw(point, out_vars, in_vars, shared, compile_kwargs=None):
    """Compile PyTensor function of the model and the input and output variables.

    Parameters
    ----------
    point : dict
        Point passed to ``join_nonshared_inputs`` to build the flattened input.
    out_vars : list
        Containing Distribution for the output variables
    in_vars : list
        Containing Distribution for the input variables
    shared : list
        Containing TensorVariable for depended shared data
    compile_kwargs: dict, optional
        Additional keyword arguments passed to pytensor.function

    Returns
    -------
    Compiled function taking one flat array of all inputs and returning the first output.
    Its ``trust_input`` flag is set, so inputs are not validated at call time.
    """
    if compile_kwargs is None:
        compile_kwargs = {}

    # Replace integer inputs with rounded float inputs
    if any(var.dtype in discrete_types for var in in_vars):
        replace_int_input = {}
        new_in_vars = []
        for in_var in in_vars:
            if in_var.dtype in discrete_types:
                float_var = pt.TensorType("floatX", in_var.type.shape)(in_var.name)
                new_in_vars.append(float_var)
                replace_int_input[in_var] = pt.round(float_var).astype(in_var.dtype)
            else:
                new_in_vars.append(in_var)

        out_vars = clone_replace(out_vars, replace_int_input, rebuild_strict=False)
        in_vars = new_in_vars

    out_list, inarray0 = join_nonshared_inputs(
        point=point, outputs=out_vars, inputs=in_vars, shared_inputs=shared
    )
    f = compile([inarray0], out_list[0], **compile_kwargs)
    f.trust_input = True
    return f