Source code for episia.models.base

"""
models/base.py - Abstract base class for all compartmental epidemic models.

Every model (SIR, SEIR, SEIRD, custom) inherits from CompartmentalModel
and shares a unified interface:
    - run()       → ModelResult
    - plot()      → Figure (delegates to viz layer)
    - summary()   → dict of key epidemiological metrics
    - to_dataframe() → pandas DataFrame of trajectories
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import numpy as np


[docs] class CompartmentalModel(ABC): """ Abstract base class for compartmental epidemic models. Subclasses must implement: compartment_names property listing compartment names (e.g. ['S','I','R']) _derivatives() ODE right-hand side _initial_state() initial conditions vector from parameters _compute_metrics() derive R0, peak, final_size from solution The run() method is fully implemented here via _derivatives() and the shared solver in models.solver. """
[docs] def __init__(self, parameters: "ModelParameters"): # noqa: F821 self.parameters = parameters self._result: Optional[Any] = None # cached ModelResult after run()
# abstract @property @abstractmethod def compartment_names(self) -> List[str]: """Ordered list of compartment names, e.g. ['S', 'I', 'R'].""" ... @abstractmethod def _derivatives(self, t: float, y: np.ndarray) -> np.ndarray: """ ODE right-hand side. Args: t: Current time. y: State vector (same order as compartment_names). Returns: dy/dt vector. """ ... @abstractmethod def _initial_state(self) -> np.ndarray: """ Build initial conditions vector from self.parameters. Returns: 1-D array with one value per compartment. """ ... @abstractmethod def _compute_metrics( self, t: np.ndarray, solution: np.ndarray, ) -> Dict[str, Any]: """ Compute epidemiological metrics from solved trajectory. Args: t: Time array. solution: (n_compartments × n_timepoints) array. Returns: Dict with at least: r0, peak_infected, peak_time, final_size. """ ... # run
[docs] def run( self, t_span: Optional[Tuple[float, float]] = None, t_eval: Optional[np.ndarray] = None, method: str = "RK45", rtol: float = 1e-6, atol: float = 1e-8, ) -> Any: """ Solve the ODE system and return a ModelResult. Args: t_span: (t_start, t_end). Defaults to parameters.t_span. t_eval: Time points at which to store solution. Defaults to np.linspace(t_start, t_end, 1000). method: scipy solve_ivp integration method (default RK45). rtol: Relative tolerance. atol: Absolute tolerance. Returns: ModelResult (also cached as self._result). Raises: RuntimeError: If the solver fails to integrate. """ from .solver import solve_model span = t_span or self.parameters.t_span y0 = self._initial_state() if t_eval is None: n_pts = max(500, int((span[1] - span[0]) * 10)) t_eval = np.linspace(span[0], span[1], n_pts) t, sol = solve_model( derivatives=self._derivatives, y0=y0, t_span=span, t_eval=t_eval, method=method, rtol=rtol, atol=atol, ) metrics = self._compute_metrics(t, sol) self._result = self._build_result(t, sol, metrics) return self._result
# result builder def _build_result( self, t: np.ndarray, sol: np.ndarray, metrics: Dict[str, Any], ) -> Any: """Assemble a ModelResult from solved trajectory and metrics.""" from ..api.results import ModelResult compartments = { name: sol[i] for i, name in enumerate(self.compartment_names) } return ModelResult( model_type=self.__class__.__name__.replace("Model", ""), t=t, compartments=compartments, parameters=self.parameters.to_dict(), r0=metrics.get("r0"), peak_infected=metrics.get("peak_infected"), peak_time=metrics.get("peak_time"), final_size=metrics.get("final_size"), ) # convenience
[docs] def plot(self, backend: str = "plotly", **kwargs) -> Any: """ Plot model trajectories. Runs the model first if not already run. Args: backend: 'plotly' or 'matplotlib'. **kwargs: Passed to PlotConfig. Returns: Figure object. """ if self._result is None: self.run() return self._result.plot(backend=backend, **kwargs)
[docs] def summary(self) -> Dict[str, Any]: """ Return key epidemiological metrics. Runs the model first if not already run. """ if self._result is None: self.run() return { "model": self.__class__.__name__, "r0": self._result.r0, "peak_infected": self._result.peak_infected, "peak_time": self._result.peak_time, "final_size": self._result.final_size, "parameters": self.parameters.to_dict(), }
[docs] def to_dataframe(self): """ Return trajectory as pandas DataFrame. Runs the model first if not already run. """ if self._result is None: self.run() return self._result.to_dataframe()
[docs] def __repr__(self) -> str: status = "ready" if self._result is None else "solved" return ( f"{self.__class__.__name__}(" f"N={self.parameters.N:,}, " f"status={status})" )
__all__ = ["CompartmentalModel"]