Source code for episia.models.scenarios

"""
models/scenarios.py - Multi-scenario runner for compartmental models.

Runs a set of parameter scenarios through a model class and returns
a structured comparison  ready for plotting or tabular export.

Public classes
--------------
    ScenarioRunner   runs N scenarios and returns ScenarioResults
    ScenarioResults  container for multi-scenario comparison
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type

import numpy as np

from .base import CompartmentalModel
from .parameters import ModelParameters, ScenarioSet


[docs] @dataclass class ScenarioResults: """ Container for results from multiple model scenarios. Attributes: labels: Scenario names. results: Corresponding ModelResult objects. metrics_df: pandas DataFrame of key metrics (lazy, call .to_dataframe()). """ labels: List[str] results: List[Any] # List[ModelResult]
[docs] def __len__(self) -> int: return len(self.labels)
[docs] def __iter__(self): return iter(zip(self.labels, self.results))
[docs] def to_dataframe(self): """Return a pandas DataFrame comparing key metrics across scenarios.""" import pandas as pd rows = [] for label, res in zip(self.labels, self.results): row = {"scenario": label} row.update({ "r0": res.r0, "peak_infected": res.peak_infected, "peak_time": res.peak_time, "final_size": res.final_size, }) # Optional SEIRD metrics if "total_deaths" in res.metadata.get("metrics", {}): row["total_deaths"] = res.metadata["metrics"]["total_deaths"] row["cfr"] = res.metadata["metrics"].get("cfr") rows.append(row) return pd.DataFrame(rows).set_index("scenario")
[docs] def plot( self, compartment: str = "I", backend: str = "plotly", theme: str = "scientific", title: str = "Scenario Comparison", ) -> Any: """ Overlay trajectories for all scenarios on a single figure. Args: compartment: Compartment to plot (e.g. 'I', 'D', 'R'). backend: 'plotly' or 'matplotlib'. theme: Theme name. title: Figure title. Returns: Figure object. """ from ..viz.themes.registry import get_palette from ..viz.plotters import PlotConfig, get_plotter pal = get_palette(theme) if backend == "plotly": import plotly.graph_objects as go from ..viz.plotters.plotly_plotter import _layout config = PlotConfig(title=title, theme=theme, xlabel="Time (days)", ylabel=f"Individuals in {compartment}") fig = go.Figure() for i, (label, res) in enumerate(zip(self.labels, self.results)): comp = res.compartments.get(compartment) if comp is None: continue fig.add_trace(go.Scatter( x=list(res.t), y=list(comp), mode="lines", name=f"{label} (R₀={res.r0:.2f})", line=dict(color=pal[i % len(pal)], width=2.2), )) fig.update_layout(_layout(config)) return fig else: import matplotlib.pyplot as plt from ..viz.themes.registry import apply_mpl_theme apply_mpl_theme(theme) fig, ax = plt.subplots(figsize=(10, 5), facecolor="white") for i, (label, res) in enumerate(zip(self.labels, self.results)): comp = res.compartments.get(compartment) if comp is None: continue ax.plot(res.t, comp, color=pal[i % len(pal)], linewidth=2.2, label=f"{label} (R₀={res.r0:.2f})") ax.set_xlabel("Time (days)", fontsize=12) ax.set_ylabel(f"Individuals in {compartment}", fontsize=12) ax.set_title(title, fontsize=14, fontweight="bold") ax.legend(fontsize=10) fig.tight_layout() return fig
[docs] def __repr__(self) -> str: return ( f"ScenarioResults({len(self.labels)} scenarios: " f"{self.labels})" )
[docs] class ScenarioRunner: """ Run multiple parameter scenarios through a single model class. Example:: from episia.models.parameters import SIRParameters, ScenarioSet from episia.models.scenarios import ScenarioRunner from episia.models.sir import SIRModel scenarios = ScenarioSet([ ("R0=1.5", SIRParameters(N=1_000_000, I0=10, beta=0.15, gamma=0.1)), ("R0=2.5", SIRParameters(N=1_000_000, I0=10, beta=0.25, gamma=0.1)), ("R0=3.5", SIRParameters(N=1_000_000, I0=10, beta=0.35, gamma=0.1)), ]) runner = ScenarioRunner(SIRModel) results = runner.run(scenarios) results.to_dataframe() results.plot(compartment="I").show() """
[docs] def __init__( self, model_class: Type[CompartmentalModel], solver_kwargs: Optional[Dict[str, Any]] = None, ): """ Args: model_class: Model class to instantiate (SIRModel, SEIRModel…). solver_kwargs: Extra kwargs forwarded to model.run(). """ self.model_class = model_class self.solver_kwargs = solver_kwargs or {}
[docs] def run(self, scenarios: ScenarioSet) -> ScenarioResults: """ Run all scenarios and return a ScenarioResults container. Args: scenarios: ScenarioSet of (label, parameters) pairs. Returns: ScenarioResults ready for plotting or DataFrame export. """ labels = [] results = [] for label, params in scenarios: model = self.model_class(params) result = model.run(**self.solver_kwargs) labels.append(label) results.append(result) return ScenarioResults(labels=labels, results=results)
__all__ = ["ScenarioRunner", "ScenarioResults"]