Source code for episia.models.sensitivity

"""
models/sensitivity.py - Monte Carlo sensitivity analysis for compartmental models.

Samples parameter distributions, runs N model instances, and aggregates
trajectories into percentile envelopes + summary statistics.

Public classes
--------------
    SensitivityAnalysis    configure distributions and run sampling
    SensitivityResult      percentile envelopes, summaries, plots

Supported distributions
-----------------------
    ("uniform",   low, high)
    ("normal",    mean, std)
    ("lognormal", mean, sigma)        # mean/std of the underlying normal
    ("triangular", low, mode, high)
    ("beta_dist",  alpha, beta)       # scipy beta, scaled to [0,1]
    ("fixed",     value)              # pin a parameter, no sampling

Performance
-----------
    Uses concurrent.futures.ProcessPoolExecutor by default.
    Falls back to sequential execution if n_jobs=1 or pickling fails.
"""

from __future__ import annotations

import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type

import numpy as np


def _cprint(text: str, color: tuple) -> None:
    """Print text in RGB colour."""
    r, g, b = color
    print(f"\033[38;2;{r};{g};{b}m{text}\033[0m")


# ─────────────────────────────────────────────────────────────────────────────
# Sampler
# ─────────────────────────────────────────────────────────────────────────────

def _sample_param(spec: Tuple, rng: np.random.Generator) -> float:
    """Draw one sample from a distribution spec."""
    kind = spec[0]

    if kind == "fixed":
        return float(spec[1])

    elif kind == "uniform":
        low, high = float(spec[1]), float(spec[2])
        return float(rng.uniform(low, high))

    elif kind == "normal":
        mean, std = float(spec[1]), float(spec[2])
        return float(rng.normal(mean, std))

    elif kind == "lognormal":
        mean, sigma = float(spec[1]), float(spec[2])
        return float(rng.lognormal(mean, sigma))

    elif kind == "triangular":
        low, mode, high = float(spec[1]), float(spec[2]), float(spec[3])
        return float(rng.triangular(low, mode, high))

    elif kind == "beta_dist":
        alpha, beta = float(spec[1]), float(spec[2])
        from scipy.stats import beta as scipy_beta
        return float(scipy_beta.rvs(alpha, beta, random_state=int(rng.integers(0, 2**31))))

    else:
        raise ValueError(
            f"Unknown distribution '{kind}'. "
            f"Choose: uniform, normal, lognormal, triangular, beta_dist, fixed."
        )


def _draw_samples(
    distributions: Dict[str, Tuple],
    n_samples: int,
    seed: Optional[int],
) -> List[Dict[str, float]]:
    """Draw n_samples parameter dicts from the given distributions."""
    rng = np.random.default_rng(seed)
    samples = []
    for _ in range(n_samples):
        draw = {k: _sample_param(spec, rng) for k, spec in distributions.items()}
        samples.append(draw)
    return samples


# ─────────────────────────────────────────────────────────────────────────────
# Worker (module-level for pickling with ProcessPoolExecutor)
# ─────────────────────────────────────────────────────────────────────────────

def _run_one(args):
    """
    Run a single model instance.

    Module-level for multiprocessing pickling on Windows (spawn).
    Robust to t_span arriving as list (JSON serialization).
    """
    model_class, param_class, fixed_params, sample, t_eval_len = args
    try:
        all_params = {**fixed_params, **sample}

        # t_span peut arriver en liste depuis sérialisation JSON
        if "t_span" in all_params and not isinstance(all_params["t_span"], tuple):
            all_params["t_span"] = tuple(all_params["t_span"])

        params = param_class(**all_params)
        model  = model_class(params)

        t_span = params.t_span
        t_eval = np.linspace(float(t_span[0]), float(t_span[1]), int(t_eval_len))
        result = model.run(t_eval=t_eval)

        return {
            "t":             np.asarray(result.t, dtype=float),
            "compartments":  {k: np.asarray(v, dtype=float)
                              for k, v in result.compartments.items()},
            "r0":            float(result.r0) if result.r0 is not None else None,
            "peak_infected": float(result.peak_infected) if result.peak_infected is not None else None,
            "peak_time":     float(result.peak_time) if result.peak_time is not None else None,
            "final_size":    float(result.final_size) if result.final_size is not None else None,
            "params":        sample,
        }
    except Exception as e:
        import traceback
        return {
            "error":  f"{type(e).__name__}: {e}",
            "detail": traceback.format_exc(),
            "params": sample,
        }


# ─────────────────────────────────────────────────────────────────────────────
# SensitivityResult
# ─────────────────────────────────────────────────────────────────────────────

[docs] @dataclass class SensitivityResult: """ Aggregated result of a Monte Carlo sensitivity analysis. Attributes: t: Common time array. envelopes: Dict compartment → {'p5','p25','p50','p75','p95'} arrays. metrics: DataFrame-ready summary of scalar metrics across runs. n_samples: Number of successful runs. n_failed: Number of failed runs. param_samples: List of sampled parameter dicts. compartment_names: Compartments present in results. """ t: np.ndarray envelopes: Dict[str, Dict[str, np.ndarray]] metrics: Dict[str, np.ndarray] # r0, peak_infected, … n_samples: int n_failed: int param_samples: List[Dict[str, float]] compartment_names: List[str] # ── Summary ──────────────────────────────────────────────────────────────
[docs] def summary(self) -> Dict[str, Any]: """ Return a dict of summary statistics for each scalar metric. Returns: Dict with keys like 'r0_median', 'r0_p5', 'peak_infected_p95', etc. """ out: Dict[str, Any] = { "n_samples": self.n_samples, "n_failed": self.n_failed, } for metric, values in self.metrics.items(): if len(values) == 0: continue out[f"{metric}_median"] = float(np.median(values)) out[f"{metric}_mean"] = float(np.mean(values)) out[f"{metric}_p5"] = float(np.percentile(values, 5)) out[f"{metric}_p95"] = float(np.percentile(values, 95)) out[f"{metric}_std"] = float(np.std(values)) return out
[docs] def to_dataframe(self): """ Return a pandas DataFrame with one row per successful run. Columns: sampled parameters + r0, peak_infected, peak_time, final_size. """ try: import pandas as pd except ImportError: raise ImportError("pandas is required. pip install pandas") rows = [] n = self.n_samples for i in range(n): row = {k: v[i] if len(v) == n else np.nan for k, v in self.metrics.items()} if i < len(self.param_samples): row.update(self.param_samples[i]) rows.append(row) return pd.DataFrame(rows)
# ── Plots ─────────────────────────────────────────────────────────────────
[docs] def plot( self, compartment: str = "I", show_samples: bool = False, n_sample_traces: int = 50, backend: str = "plotly", theme: str = "scientific", title: Optional[str] = None, ) -> Any: """ Plot percentile envelope for a compartment. Args: compartment: Compartment name (default 'I'). show_samples: Overlay individual sample trajectories. n_sample_traces: How many individual traces to show (max). backend: 'plotly' or 'matplotlib'. theme: Theme name. title: Figure title (auto if None). Returns: Figure object. """ if compartment not in self.envelopes: raise ValueError( f"Compartment '{compartment}' not in results. " f"Available: {self.compartment_names}" ) env = self.envelopes[compartment] t = self.t title = title or ( f"Sensitivity Analysis {compartment} " f"(n={self.n_samples})" ) from ..viz.themes.registry import get_palette pal = get_palette(theme) col = pal[0] if backend == "plotly": return self._plot_plotly( t, env, compartment, col, title, show_samples, n_sample_traces, theme, ) else: return self._plot_mpl( t, env, compartment, col, title, show_samples, n_sample_traces, theme, )
def _plot_plotly(self, t, env, comp, col, title, show_samples, n_traces, theme): import plotly.graph_objects as go from ..viz.plotters.plotly_plotter import _layout, _FONT_COLOR from ..viz.plotters import PlotConfig def rgba(hex_col, alpha): h = hex_col.lstrip("#") r,g,b = int(h[:2],16), int(h[2:4],16), int(h[4:],16) return f"rgba({r},{g},{b},{alpha})" fig = go.Figure() # p5–p95 band fig.add_trace(go.Scatter( x=list(t) + list(t[::-1]), y=list(env["p95"]) + list(env["p5"][::-1]), fill="toself", fillcolor=rgba(col, 0.12), line=dict(color="rgba(0,0,0,0)"), name="5e–95e percentile", hoverinfo="skip", )) # p25–p75 band fig.add_trace(go.Scatter( x=list(t) + list(t[::-1]), y=list(env["p75"]) + list(env["p25"][::-1]), fill="toself", fillcolor=rgba(col, 0.25), line=dict(color="rgba(0,0,0,0)"), name="25e–75e percentile", hoverinfo="skip", )) # Median fig.add_trace(go.Scatter( x=list(t), y=list(env["p50"]), mode="lines", line=dict(color=col, width=2.5), name="Médiane", )) # Individual traces if show_samples and "all_trajectories" in self.envelopes.get(comp + "_raw", {}): pass # stored separately if needed config = PlotConfig( title=title, theme=theme, xlabel="Jours", ylabel=f"Individus ({comp})", ) from ..viz.plotters.plotly_plotter import _layout fig.update_layout(_layout(config)) return fig def _plot_mpl(self, t, env, comp, col, title, show_samples, n_traces, theme): import matplotlib.pyplot as plt from ..viz.themes.registry import apply_mpl_theme apply_mpl_theme(theme) fig, ax = plt.subplots(figsize=(11, 5), facecolor="white") ax.fill_between(t, env["p5"], env["p95"], alpha=0.12, color=col, label="5e–95e percentile") ax.fill_between(t, env["p25"], env["p75"], alpha=0.28, color=col, label="25e–75e percentile") ax.plot(t, env["p50"], color=col, linewidth=2.5, label="Médiane") ax.set_xlabel("Jours", fontsize=11) ax.set_ylabel(f"Individus ({comp})", fontsize=11) ax.set_title(title, fontsize=13, fontweight="bold") ax.legend(fontsize=10) ax.grid(True, alpha=0.3, linestyle="--") ax.spines[["top", "right"]].set_visible(False) fig.tight_layout() return fig
[docs] def plot_metric_distribution( self, metric: str = "r0", backend: str = "plotly", theme: str = "scientific", ) -> Any: """ Histogram of a scalar metric across all runs. Args: metric: One of 'r0', 'peak_infected', 'peak_time', 'final_size'. backend: 'plotly' or 'matplotlib'. theme: Theme name. Returns: Figure object. """ if metric not in self.metrics: raise ValueError( f"Metric '{metric}' not found. " f"Available: {list(self.metrics.keys())}" ) values = self.metrics[metric] label_map = { "r0": "R₀", "peak_infected": "Pic infectieux", "peak_time": "Jour du pic", "final_size": "Taille finale (fraction)", } label = label_map.get(metric, metric) median = float(np.median(values)) p5 = float(np.percentile(values, 5)) p95 = float(np.percentile(values, 95)) title = f"Distribution {label} (médiane={median:.3f})" from ..viz.themes.registry import get_palette pal = get_palette(theme) if backend == "plotly": import plotly.graph_objects as go from ..viz.plotters import PlotConfig from ..viz.plotters.plotly_plotter import _layout fig = go.Figure() fig.add_trace(go.Histogram( x=list(values), nbinsx=40, marker_color=pal[0], opacity=0.75, name=label, )) fig.add_vline(x=median, line=dict(color=pal[1], width=2, dash="dash"), annotation_text=f"Médiane {median:.3f}", annotation_position="top right") fig.add_vrect(x0=p5, x1=p95, fillcolor=pal[0], opacity=0.08, line_width=0, annotation_text="5e–95e", annotation_position="top left") config = PlotConfig(title=title, theme=theme, xlabel=label, ylabel="Fréquence") fig.update_layout(_layout(config)) return fig else: import matplotlib.pyplot as plt from ..viz.themes.registry import apply_mpl_theme apply_mpl_theme(theme) fig, ax = plt.subplots(figsize=(8, 4), facecolor="white") ax.hist(values, bins=40, color=pal[0], alpha=0.75, edgecolor="white") ax.axvline(median, color=pal[1], linewidth=2, linestyle="--", label=f"Médiane {median:.3f}") ax.axvspan(p5, p95, alpha=0.10, color=pal[0], label="5e–95e percentile") ax.set_xlabel(label, fontsize=11) ax.set_ylabel("Fréquence", fontsize=11) ax.set_title(title, fontsize=13, fontweight="bold") ax.legend(fontsize=10) ax.spines[["top", "right"]].set_visible(False) fig.tight_layout() return fig
[docs] def __repr__(self) -> str: s = self.summary() lines = [ f"SensitivityResult n={self.n_samples} ({self.n_failed} failed)", ] for m in ["r0", "peak_infected", "final_size"]: if f"{m}_median" in s: lines.append( f" {m:20s}: median={s[f'{m}_median']:.3f} " f"[{s[f'{m}_p5']:.3f}, {s[f'{m}_p95']:.3f}]" ) return "\n".join(lines)
# ───────────────────────────────────────────────────────────────────────────── # SensitivityAnalysis # ─────────────────────────────────────────────────────────────────────────────
[docs] class SensitivityAnalysis: """ Monte Carlo sensitivity analysis for compartmental epidemic models. Example:: from episia.models.sensitivity import SensitivityAnalysis from episia.models import SEIRModel from episia.models.parameters import SEIRParameters sa = SensitivityAnalysis( model_class=SEIRModel, param_class=SEIRParameters, fixed=dict(N=1_000_000, I0=10, E0=50, t_span=(0, 365)), distributions={ 'beta': ('uniform', 0.25, 0.50), 'sigma': ('normal', 1/5.2, 0.02), 'gamma': ('uniform', 1/21, 1/7), }, n_samples=500, seed=42, ) result = sa.run() print(result) result.plot(compartment='I').show() result.plot_metric_distribution('r0').show() result.to_dataframe() """
[docs] def __init__( self, model_class, param_class, fixed: Dict[str, Any], distributions: Dict[str, Tuple], n_samples: int = 200, seed: Optional[int] = 42, n_jobs: int = 1, t_eval_points: int = 500, ): """ Args: model_class: CompartmentalModel subclass. param_class: Matching parameters class. fixed: Parameters held constant across all runs. distributions: Parameters to sample; values are distribution specs. E.g. {'beta': ('uniform', 0.2, 0.5)}. n_samples: Number of Monte Carlo draws. seed: Random seed for reproducibility. n_jobs: Parallel workers (1 = sequential, -1 = all CPUs). t_eval_points: Number of time points per trajectory. """ self.model_class = model_class self.param_class = param_class self.fixed = fixed self.distributions = distributions self.n_samples = n_samples self.seed = seed self.n_jobs = n_jobs self.t_eval_points = t_eval_points
# ── run ──────────────────────────────────────────────────────────────────
[docs] def run(self, verbose: bool = True) -> SensitivityResult: """ Draw samples, run all models, and return aggregated SensitivityResult. Args: verbose: Print progress summary. Returns: SensitivityResult. """ if verbose: print(f"Sampling {self.n_samples} parameter sets…") samples = _draw_samples(self.distributions, self.n_samples, self.seed) # Validate one sample first to catch config errors early self._validate_one(samples[0]) raw_results = self._execute_with_progress(samples, verbose) n_ok = sum(1 for r in raw_results if "error" not in r) n_fail = len(raw_results) - n_ok if n_fail > 0 and verbose: errors = [r for r in raw_results if "error" in r][:3] for r in errors: _cprint(f" [erreur] {r['error']}", (255, 80, 80)) if "detail" in r and n_ok == 0: print(r["detail"][:600]) return self._aggregate(raw_results, samples)
# ── internal ───────────────────────────────────────────────────────────── def _validate_one(self, sample: Dict[str, float]) -> None: """Instantiate one model to catch parameter errors before full run.""" try: all_params = {**self.fixed, **sample} self.param_class(**all_params) except Exception as e: raise ValueError( f"Parameter validation failed for sample {sample}: {e}" ) # ── Progress helpers ───────────────────────────────────────────────────── def _execute_with_progress( self, samples: List[Dict], verbose: bool ) -> List[Dict]: """Run all models with a coloured gradient progress bar.""" n = len(samples) if not verbose: return self._execute(samples) # Gradient colours teal → violet → rose _GRAD = [ (0, 210, 190), (0, 180, 255), (100, 120, 255), (200, 60, 220), (240, 80, 160), ] def _lerp(a, b, t): return int(a + (b - a) * t) def _grad_color(pos, total): t = pos / max(total - 1, 1) n_stops = len(_GRAD) - 1 i = min(int(t * n_stops), n_stops - 1) lt = t * n_stops - i r = _lerp(_GRAD[i][0], _GRAD[i+1][0], lt) g = _lerp(_GRAD[i][1], _GRAD[i+1][1], lt) b = _lerp(_GRAD[i][2], _GRAD[i+1][2], lt) return r, g, b def _rgb(r, g, b, text): return f"\033[38;2;{r};{g};{b}m{text}\033[0m" def _bold(text): return f"\033[1m{text}\033[0m" bar_width = 36 label = "Episia · Monte Carlo" def _draw(done, total, n_ok, n_fail): frac = done / max(total, 1) filled = int(frac * bar_width) bar = "" for i in range(bar_width): r, g, b = _grad_color(i, bar_width) char = "█" if i < filled else "░" bar += _rgb(r, g, b, char) r0, g0, b0 = _grad_color(filled, bar_width) pct = _rgb(r0, g0, b0, f"{frac*100:5.1f}%") stat = f" {done}/{total}" if n_fail: stat += f" \033[38;2;255;80;80m✗ {n_fail}\033[0m" line = f" {_bold(label)} {bar} {pct}{stat}" print(f"\r{line}", end="", flush=True) results = [] n_ok = n_fail = 0 args_list = [ (self.model_class, self.param_class, self.fixed, s, self.t_eval_points) for s in samples ] print() # newline before bar for i, args in enumerate(args_list): r = _run_one(args) results.append(r) if "error" in r: n_fail += 1 else: n_ok += 1 _draw(i + 1, n, n_ok, n_fail) r0, g0, b0 = _grad_color(bar_width - 1, bar_width) done_text = _rgb(r0, g0, b0, "✓ done") print(f" {done_text} {n_ok}/{n} OK\n", flush=True) return results def _execute(self, samples: List[Dict]) -> List[Dict]: """Run all models sequential (n_jobs=1) or parallel.""" args_list = [ (self.model_class, self.param_class, self.fixed, s, self.t_eval_points) for s in samples ] if self.n_jobs == 1: return [_run_one(a) for a in args_list] # Parallel fallback to sequential on error try: workers = ( self.n_jobs if self.n_jobs > 0 else __import__("os").cpu_count() ) results = [None] * len(args_list) with ProcessPoolExecutor(max_workers=workers) as ex: future_to_idx = { ex.submit(_run_one, a): i for i, a in enumerate(args_list) } for future in as_completed(future_to_idx): idx = future_to_idx[future] try: results[idx] = future.result() except Exception as e: results[idx] = {"error": str(e), "params": samples[idx]} return results except Exception: warnings.warn( "Parallel execution failed, falling back to sequential.", RuntimeWarning, ) return [_run_one(a) for a in args_list] def _aggregate( self, raw: List[Dict], samples: List[Dict], ) -> SensitivityResult: """Build SensitivityResult from list of raw run dicts.""" good = [r for r in raw if "error" not in r] n_fail = len(raw) - len(good) if not good: raise RuntimeError( f"All {len(raw)} model runs failed. " f"Check your parameter distributions and fixed params." ) # Common time array use first successful run's t t_ref = good[0]["t"] # Interpolate all trajectories onto t_ref comp_names = list(good[0]["compartments"].keys()) all_traj: Dict[str, List[np.ndarray]] = {c: [] for c in comp_names} for run in good: t_run = run["t"] for c in comp_names: interp = np.interp(t_ref, t_run, run["compartments"][c]) all_traj[c].append(interp) # Percentile envelopes percentiles = [5, 25, 50, 75, 95] envelopes: Dict[str, Dict[str, np.ndarray]] = {} for c in comp_names: stack = np.vstack(all_traj[c]) # (n_good, n_timepoints) envelopes[c] = { f"p{p}": np.percentile(stack, p, axis=0) for p in percentiles } # Scalar metrics metrics: Dict[str, np.ndarray] = { "r0": np.array([r["r0"] for r in good if r.get("r0") is not None]), "peak_infected": np.array([r["peak_infected"] for r in good if r.get("peak_infected") is not None]), "peak_time": np.array([r["peak_time"] for r in good if r.get("peak_time") is not None]), "final_size": np.array([r["final_size"] for r in good if r.get("final_size") is not None]), } # Param samples that succeeded good_param_samples = [r["params"] for r in good] return SensitivityResult( t=t_ref, envelopes=envelopes, metrics=metrics, n_samples=len(good), n_failed=n_fail, param_samples=good_param_samples, compartment_names=comp_names, )
__all__ = ["SensitivityAnalysis", "SensitivityResult"]