Source code for episia.viz.plotters.mpl_plotter

"""
viz/plotters/mpl_plotter.py - Matplotlib rendering backend for Episia.

Purpose: publication-quality static figures only.
    - No animations (use PlotlyPlotter for interactive / animated output)
    - Respects .mplstyle theme files
    - Returns matplotlib Figure objects, saveable at any DPI
    - Suitable for journal submissions, theses, reports

Supported output formats: PNG, SVG, PDF (via fig.savefig)
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from .base_plotter import (
    AnimationConfig,
    AnimationType,
    BasePlotter,
    OutputFormat,
    PlotConfig,
    UnsupportedAnimationError,
)



# Theme helpers

# Per-theme series colours, intended to mirror the palettes used by
# plotly_plotter so both backends colour a given theme the same way.
_PALETTES: Dict[str, List[str]] = dict(
    scientific=[
        "#1f77b4", "#d62728", "#2ca02c", "#ff7f0e",
        "#9467bd", "#8c564b", "#e377c2", "#7f7f7f",
    ],
    minimal=["#333333", "#888888", "#bbbbbb", "#555555", "#aaaaaa"],
    dark=[
        "#64b5f6", "#ef5350", "#66bb6a", "#ffa726",
        "#ab47bc", "#26c6da", "#d4e157", "#ff7043",
    ],
    colorblind=[
        "#0072B2", "#E69F00", "#56B4E9", "#009E73",
        "#F0E442", "#D55E00", "#CC79A7", "#999999",
    ],
)

# Figure / axes background colour for each theme.
_BG: Dict[str, str] = dict(
    scientific="white",
    minimal="white",
    dark="#1e1e2e",
    colorblind="white",
)

# Text colour (titles, labels, ticks) for each theme.
_FONT_COLOR: Dict[str, str] = dict(
    scientific="#222222",
    minimal="#333333",
    dark="#eeeeee",
    colorblind="#222222",
)

# Directory holding the bundled "<theme>.mplstyle" files: the sibling
# "themes" directory one level above this module's directory.
import os as _os
_THEME_DIR = _os.path.join(_os.path.dirname(__file__), _os.pardir, "themes")


# Built-in Matplotlib styles used when a theme has no usable .mplstyle file.
# Candidates are tried in order: the "seaborn-v0_8-*" names exist from
# Matplotlib 3.6 on, the bare "seaborn-*" names only in older releases.
_FALLBACK_STYLES: Dict[str, Tuple[str, ...]] = {
    "scientific": ("seaborn-v0_8-paper", "seaborn-paper"),
    "minimal":    ("seaborn-v0_8-whitegrid", "seaborn-whitegrid"),
    "dark":       ("dark_background",),
    "colorblind": ("seaborn-v0_8-colorblind", "seaborn-colorblind"),
}


def _apply_theme(theme: str) -> None:
    """Apply the Matplotlib style for *theme*.

    Resolution order:

    1. ``<_THEME_DIR>/<theme>.mplstyle`` if the file exists and is non-empty.
    2. The first entry of ``_FALLBACK_STYLES[theme]`` that this Matplotlib
       version ships (listed in ``matplotlib.style.available``).
    3. ``"default"``, which is also used for unknown theme names.

    Note: ``matplotlib.style.use`` updates the global rcParams, so the
    applied style remains active after this function returns.
    """
    # Import the submodule explicitly: a bare ``import matplotlib`` does not
    # guarantee that ``matplotlib.style`` has been loaded.
    import matplotlib.style as mstyle

    style_path = _os.path.join(_THEME_DIR, f"{theme}.mplstyle")
    if _os.path.isfile(style_path) and _os.path.getsize(style_path) > 0:
        mstyle.use(style_path)
        return

    # Safe fallback for empty/missing style files: pick a library style that
    # actually exists in the installed Matplotlib instead of swallowing errors.
    available = set(mstyle.available)
    for candidate in _FALLBACK_STYLES.get(theme, ()):
        if candidate in available:
            mstyle.use(candidate)
            return
    mstyle.use("default")


def _palette(cfg: PlotConfig) -> List[str]:
    """Return the list of series colours to use for *cfg*.

    A non-empty ``cfg.palette`` takes precedence; otherwise the palette
    registered for ``cfg.theme`` is used, and unknown themes get the
    "scientific" palette.
    """
    custom = cfg.palette
    if custom:
        return custom
    default = _PALETTES["scientific"]
    return _PALETTES.get(cfg.theme, default)


def _px_to_in(px: int, dpi: int = 100) -> float:
    """Convert a length in pixels to inches at *dpi* pixels per inch."""
    pixels_per_inch = dpi
    inches = px / pixels_per_inch
    return inches


def _style_axes(ax, cfg: PlotConfig) -> None:
    """Apply the theme's background, tick, spine and grid styling to *ax*.

    Reads ``cfg.theme``, ``cfg.font_size`` and ``cfg.show_grid``. When the
    grid is enabled it is dashed, faint and drawn below the plotted data.
    """
    is_dark = cfg.theme == "dark"
    tick_color = _FONT_COLOR.get(cfg.theme, "#222222")

    ax.set_facecolor(_BG.get(cfg.theme, "white"))
    ax.tick_params(colors=tick_color, labelsize=cfg.font_size - 1)

    spine_color = "#444466" if is_dark else "#cccccc"
    for spine in ax.spines.values():
        spine.set_edgecolor(spine_color)

    if cfg.show_grid:
        grid_color = "#333355" if is_dark else "#dddddd"
        ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.6,
                color=grid_color)
        ax.set_axisbelow(True)
    else:
        ax.grid(False)



# MatplotlibPlotter

class MatplotlibPlotter(BasePlotter):
    """
    Matplotlib rendering backend: static, publication-quality figures.

    Every ``plot_*`` method returns a ``matplotlib.figure.Figure``.
    No animations are supported (``SUPPORTED_ANIMATIONS`` is empty);
    use PlotlyPlotter for animated output.

    Call ``.save()`` or ``fig.savefig(path, dpi=300)`` to export.

    Note: every ``plot_*`` method calls ``_apply_theme``, which applies a
    Matplotlib style globally, so the style stays active after the call.
    """

    BACKEND_NAME = "matplotlib"
    SUPPORTED_ANIMATIONS: Tuple[AnimationType, ...] = ()  # none

    # shared helpers

    @staticmethod
    def _format_p(p: float) -> str:
        """Format a p-value as ``p<0.001`` or ``p=<value to 3 decimals>``."""
        return "p<0.001" if p < 0.001 else f"p={p:.3f}"

    @staticmethod
    def _cell_text_color(value: float, vmax: float) -> str:
        """Return a readable annotation colour for a "Blues" heatmap cell.

        Cell colours come from the colormap, not from the theme, so the
        theme font colour (light grey in the "dark" theme) would vanish on
        pale low-value cells. Cells above 60% of the maximum get white text;
        all others get near-black text.
        """
        return "white" if value > vmax * 0.6 else "#222222"

    @staticmethod
    def _accent(pal: List[str]) -> str:
        """Return the second palette colour, wrapping for 1-colour palettes."""
        return pal[1 % len(pal)]

    @staticmethod
    def _forest_rows(result: Any) -> List[Dict[str, Any]]:
        """Flatten *result* into forest-plot rows (same logic as PlotlyPlotter).

        Handles three shapes of result:
        - stratified: one row per entry of ``stratum_results`` plus a final
          pooled "Pooled (MH)" row;
        - regression-like: one row per ``coefficients`` entry, with bounds
          from ``ci_table`` (defaulting to the coefficient itself);
        - single estimate: one row from ``estimate`` / ``ci`` / ``p_value``.
        """
        rows: List[Dict[str, Any]] = []
        if hasattr(result, "stratum_results") and result.stratum_results:
            for s in result.stratum_results:
                rows.append(dict(
                    label=s.metadata.get("label", s.measure),
                    est=s.estimate, lo=s.ci.lower, hi=s.ci.upper,
                    p=s.p_value, pooled=False,
                ))
            rows.append(dict(
                label="Pooled (MH)",
                est=result.mh_estimate, lo=result.ci.lower, hi=result.ci.upper,
                p=result.p_value, pooled=True,
            ))
        elif hasattr(result, "coefficients"):
            for var, coef in result.coefficients.items():
                lo, hi = result.ci_table.get(var, (coef, coef))
                rows.append(dict(
                    label=var, est=coef, lo=lo, hi=hi,
                    p=result.p_values.get(var), pooled=False,
                ))
        else:
            rows.append(dict(
                label=getattr(result, "measure", "estimate"),
                est=result.estimate, lo=result.ci.lower, hi=result.ci.upper,
                p=result.p_value, pooled=False,
            ))
        return rows

    # epicurve

    def plot_epicurve(
        self,
        result: Any,
        config: Optional[PlotConfig] = None,
    ) -> Any:
        """
        Epidemic curve: bar chart suitable for publication.

        Reads ``result.times``, ``result.values``, ``result.trend`` (``None``
        when absent) and ``result.trend_method``. When a trend is present it
        is overlaid as a dashed line and the legend is shown.
        """
        import matplotlib.pyplot as plt

        cfg = self._resolve_config(config)
        self._check_animation(cfg)  # raises if animation requested
        _apply_theme(cfg.theme)
        pal = _palette(cfg)
        fc = _FONT_COLOR.get(cfg.theme, "#222222")

        fig, ax = plt.subplots(
            figsize=(_px_to_in(cfg.width), _px_to_in(cfg.height)),
            facecolor=_BG.get(cfg.theme, "white"),
        )

        times = list(result.times)
        values = list(result.values)
        ax.bar(times, values, color=pal[0], alpha=0.85,
               edgecolor="none", label="Cases")

        if result.trend is not None:
            ax.plot(times, list(result.trend), color=self._accent(pal),
                    linewidth=2, linestyle="--",
                    label=result.trend_method or "Trend")

        ax.set_xlabel(cfg.xlabel or "Period", color=fc, fontsize=cfg.font_size)
        ax.set_ylabel(cfg.ylabel or "Cases", color=fc, fontsize=cfg.font_size)

        # Title and subtitle share one title block; either one may be unset,
        # so join only the parts that are present.
        heading = "\n".join(part for part in (cfg.title, cfg.subtitle) if part)
        if heading:
            ax.set_title(heading, color=fc, fontsize=cfg.font_size + 2,
                         fontweight="bold")

        if cfg.show_legend and result.trend is not None:
            ax.legend(fontsize=cfg.font_size - 1)

        _style_axes(ax, cfg)
        fig.tight_layout()
        return fig

    # model

    def plot_model(
        self,
        result: Any,
        config: Optional[PlotConfig] = None,
    ) -> Any:
        """
        Compartmental model trajectories: one line per compartment.

        Reads ``result.t``, ``result.compartments`` (name -> series),
        ``result.peak_infected``, ``result.peak_time``, ``result.r0`` and
        ``result.model_type``. A peak marker and annotation are drawn when
        both peak values are set, and an R₀ box when ``r0`` is set.
        """
        import matplotlib.pyplot as plt

        cfg = self._resolve_config(config)
        self._check_animation(cfg)
        _apply_theme(cfg.theme)
        pal = _palette(cfg)
        fc = _FONT_COLOR.get(cfg.theme, "#222222")
        bg = _BG.get(cfg.theme, "white")

        fig, ax = plt.subplots(
            figsize=(_px_to_in(cfg.width), _px_to_in(cfg.height)),
            facecolor=bg,
        )

        t = result.t
        for i, (name, arr) in enumerate(result.compartments.items()):
            ax.plot(t, arr, color=pal[i % len(pal)], linewidth=2.2, label=name)

        # Peak annotation
        if result.peak_infected is not None and result.peak_time is not None:
            ax.axvline(result.peak_time, color="#999999", linewidth=1,
                       linestyle=":", alpha=0.8)
            # Label sits 4% of the time span to the right, slightly below peak.
            ax.annotate(
                f"Peak: {result.peak_infected:,.0f}\nt={result.peak_time:.1f}",
                xy=(result.peak_time, result.peak_infected),
                xytext=(result.peak_time + (t[-1] - t[0]) * 0.04,
                        result.peak_infected * 0.92),
                fontsize=cfg.font_size - 2,
                color=fc,
                arrowprops=dict(arrowstyle="->", color="#999999", lw=0.8),
            )

        # R0 box: filled with the theme background so the theme font colour
        # stays readable (a white box hides the light "dark"-theme text).
        if result.r0 is not None:
            ax.text(
                0.02, 0.97, f"R₀ = {result.r0:.2f}",
                transform=ax.transAxes,
                fontsize=cfg.font_size - 1,
                verticalalignment="top",
                color=fc,
                bbox=dict(
                    boxstyle="round,pad=0.3",
                    facecolor=bg,
                    edgecolor="#444466" if cfg.theme == "dark" else "#cccccc",
                    alpha=0.8,
                ),
            )

        ax.set_xlabel(cfg.xlabel or "Time", color=fc, fontsize=cfg.font_size)
        ax.set_ylabel(cfg.ylabel or "Population", color=fc,
                      fontsize=cfg.font_size)
        ax.set_title(cfg.title or f"{result.model_type} Model",
                     color=fc, fontsize=cfg.font_size + 2, fontweight="bold")

        if cfg.show_legend:
            ax.legend(fontsize=cfg.font_size - 1, framealpha=0.8)

        _style_axes(ax, cfg)
        fig.tight_layout()
        return fig

    # ROC

    def plot_roc(
        self,
        result: Any,
        config: Optional[PlotConfig] = None,
    ) -> Any:
        """
        ROC curve: square axes, AUC in the legend, optimal threshold marker.

        Reads ``result.fpr``, ``result.tpr``, ``result.auc``,
        ``result.optimal_threshold`` and ``result.optimal_point`` (a mapping
        with "sensitivity" / "specificity"). The figure is square, using the
        smaller of ``cfg.width`` / ``cfg.height``.
        """
        import matplotlib.pyplot as plt

        cfg = self._resolve_config(config)
        self._check_animation(cfg)
        _apply_theme(cfg.theme)
        pal = _palette(cfg)
        fc = _FONT_COLOR.get(cfg.theme, "#222222")

        size = _px_to_in(min(cfg.width, cfg.height))
        fig, ax = plt.subplots(figsize=(size, size),
                               facecolor=_BG.get(cfg.theme, "white"))

        # Fill under curve
        ax.fill_between(result.fpr, result.tpr, alpha=0.08, color=pal[0])
        ax.plot(result.fpr, result.tpr, color=pal[0], linewidth=2.5,
                label=f"AUC = {result.auc:.3f}")

        # Reference diagonal
        ax.plot([0, 1], [0, 1], color="#aaaaaa", linewidth=1, linestyle="--",
                label="Random")

        # Optimal threshold: ROC space plots FPR = 1 - specificity against
        # TPR = sensitivity.
        opt_fpr = 1 - result.optimal_point.get("specificity", 0)
        opt_tpr = result.optimal_point.get("sensitivity", 0)
        ax.scatter([opt_fpr], [opt_tpr], color=self._accent(pal), s=80,
                   zorder=5,
                   label=f"Optimal (t={result.optimal_threshold:.3f})")
        ax.annotate(
            f" Sens={opt_tpr:.3f}\n Spec={1-opt_fpr:.3f}",
            xy=(opt_fpr, opt_tpr),
            xytext=(opt_fpr + 0.05, opt_tpr - 0.07),
            fontsize=cfg.font_size - 2,
            color=fc,
        )

        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1.02)
        ax.set_aspect("equal")
        ax.set_xlabel(cfg.xlabel or "False Positive Rate (1 − Specificity)",
                      color=fc, fontsize=cfg.font_size)
        ax.set_ylabel(cfg.ylabel or "True Positive Rate (Sensitivity)",
                      color=fc, fontsize=cfg.font_size)
        ax.set_title(cfg.title or "ROC Curve", color=fc,
                     fontsize=cfg.font_size + 2, fontweight="bold")

        if cfg.show_legend:
            ax.legend(fontsize=cfg.font_size - 1, loc="lower right",
                      framealpha=0.85)

        _style_axes(ax, cfg)
        fig.tight_layout()
        return fig

    # forest

    def plot_forest(
        self,
        result: Any,
        config: Optional[PlotConfig] = None,
    ) -> Any:
        """
        Forest plot: horizontal CI lines, suitable for meta-analysis tables.

        Rows come from ``_forest_rows``; the first row is drawn at the top and
        a pooled row (if any) is separated by a thin line. A dashed reference
        line marks the null value: ``result.null_value`` when the result has
        one, otherwise 1 when every estimate exceeds 0.01 (ratio measures)
        and 0 otherwise.

        Raises ValueError if *result* yields no rows to plot.
        """
        import matplotlib.pyplot as plt

        cfg = self._resolve_config(config)
        self._check_animation(cfg)
        _apply_theme(cfg.theme)
        pal = _palette(cfg)
        fc = _FONT_COLOR.get(cfg.theme, "#222222")

        rows = self._forest_rows(result)
        n = len(rows)
        if n == 0:
            raise ValueError("plot_forest: result has no estimates to plot")

        # Figure height (inches) grows with the number of rows.
        row_height = max(0.45, 4.5 / n)
        fig_h = max(3.5, n * row_height + 1.2)
        fig, ax = plt.subplots(
            figsize=(_px_to_in(cfg.width), fig_h),
            facecolor=_BG.get(cfg.theme, "white"),
        )

        null_val = getattr(result, "null_value", None)
        if null_val is None:
            # Heuristic when the result does not state its null value.
            null_val = 1.0 if all(r["est"] > 0.01 for r in rows) else 0.0
        ax.axvline(null_val, color="#999999", linewidth=1, linestyle="--",
                   alpha=0.7)

        # p-value labels sit just right of each CI, padded by 2% of the x span
        # so they clear the whisker even when the upper bound is negative.
        x_span = max(r["hi"] for r in rows) - min(r["lo"] for r in rows)
        text_pad = 0.02 * (x_span or 1.0)
        accent = self._accent(pal)

        for i, row in enumerate(rows):
            y = n - 1 - i  # first row at the top
            pooled = row["pooled"]
            color = accent if pooled else pal[0]
            marker = "D" if pooled else "s"
            ms = 9 if pooled else 7

            # CI whisker
            ax.plot([row["lo"], row["hi"]], [y, y], color=color,
                    linewidth=1.8, solid_capstyle="round")
            # Point estimate
            ax.plot(row["est"], y, marker=marker, color=color,
                    markersize=ms, zorder=5)
            # Separator before pooled
            if pooled:
                ax.axhline(y + 0.5, color="#cccccc", linewidth=0.8,
                           linestyle="-")
            # p-value text
            if row.get("p") is not None:
                ax.text(row["hi"] + text_pad, y,
                        f" {self._format_p(row['p'])}", va="center",
                        fontsize=cfg.font_size - 2, color=fc)

        ax.set_yticks(range(n))
        ax.set_yticklabels([r["label"] for r in reversed(rows)],
                           fontsize=cfg.font_size - 1, color=fc)
        ax.set_xlabel(cfg.xlabel or "Estimate", color=fc,
                      fontsize=cfg.font_size)
        ax.set_title(cfg.title or "Forest Plot", color=fc,
                     fontsize=cfg.font_size + 2, fontweight="bold")
        ax.set_ylim(-0.7, n - 0.3)

        _style_axes(ax, cfg)
        ax.yaxis.grid(False)
        fig.tight_layout()
        return fig

    # association

    def plot_association(
        self,
        result: Any,
        config: Optional[PlotConfig] = None,
    ) -> Any:
        """
        Single association measure: horizontal CI with a null reference line.

        Compact figure (1.8 in tall), suitable as an inline element in a
        report. Reads ``measure``, ``estimate``, ``ci`` (``lower``, ``upper``,
        ``confidence``), ``null_value``, ``p_value`` (may be ``None``) and
        ``significant`` from *result*.
        """
        import matplotlib.pyplot as plt

        cfg = self._resolve_config(config)
        self._check_animation(cfg)
        _apply_theme(cfg.theme)
        pal = _palette(cfg)
        fc = _FONT_COLOR.get(cfg.theme, "#222222")

        label = result.measure.replace("_", " ").title()
        fig, ax = plt.subplots(
            figsize=(_px_to_in(cfg.width), 1.8),
            facecolor=_BG.get(cfg.theme, "white"),
        )

        # CI whisker
        ax.plot([result.ci.lower, result.ci.upper], [0, 0],
                color=pal[0], linewidth=4, solid_capstyle="round")
        # Point estimate
        ax.plot(result.estimate, 0, "D", color=self._accent(pal),
                markersize=12, zorder=5)
        # Null reference
        ax.axvline(result.null_value, color="#999999", linewidth=1,
                   linestyle="--")

        # Annotation. ":g" prints 0.95 as "95" and 0.975 as "97.5"; int()
        # truncation would print e.g. 0.57 (56.999... * 100) as "56".
        p_str = "" if result.p_value is None else self._format_p(result.p_value)
        ci_str = (f"{result.ci.confidence * 100:g}% CI "
                  f"[{result.ci.lower:.3f}, {result.ci.upper:.3f}]")
        sig = "Significant" if result.significant else "NS"
        status = f"{p_str} {sig}".strip()  # no leading space without a p-value
        ax.text(
            0.99, 0.5,
            f"{result.estimate:.3f} {ci_str}\n{status}",
            transform=ax.transAxes,
            ha="right", va="center",
            fontsize=cfg.font_size - 1,
            color=fc,
        )

        ax.set_yticks([])
        ax.set_xlabel(cfg.xlabel or label, color=fc, fontsize=cfg.font_size)
        ax.set_title(cfg.title or label, color=fc,
                     fontsize=cfg.font_size + 1, fontweight="bold")

        _style_axes(ax, cfg)
        ax.yaxis.set_visible(False)
        for spine in ["left", "top", "right"]:
            ax.spines[spine].set_visible(False)
        fig.tight_layout()
        return fig

    # diagnostic

    def plot_diagnostic(
        self,
        result: Any,
        config: Optional[PlotConfig] = None,
    ) -> Any:
        """
        Diagnostic dashboard: confusion-matrix heatmap + metrics bar chart.

        Two-panel layout. Reads ``tn``, ``fp``, ``fn``, ``tp``,
        ``sensitivity``, ``specificity``, ``ppv``, ``npv``, ``accuracy`` and
        ``youden`` from *result*. The bar axis extends below zero when
        Youden's J (or any metric) is negative.
        """
        import matplotlib.pyplot as plt
        import matplotlib.gridspec as gridspec

        cfg = self._resolve_config(config)
        self._check_animation(cfg)
        _apply_theme(cfg.theme)
        pal = _palette(cfg)
        fc = _FONT_COLOR.get(cfg.theme, "#222222")
        bg = _BG.get(cfg.theme, "white")

        fig = plt.figure(
            figsize=(_px_to_in(cfg.width), _px_to_in(cfg.height)),
            facecolor=bg,
        )
        gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1.4], figure=fig)
        ax_cm = fig.add_subplot(gs[0])
        ax_bar = fig.add_subplot(gs[1])
        ax_cm.set_facecolor(bg)
        ax_bar.set_facecolor(bg)

        # Confusion matrix heatmap: rows = actual (neg, pos), cols = predicted.
        cm = np.array([
            [result.tn, result.fp],
            [result.fn, result.tp],
        ], dtype=float)
        ax_cm.imshow(cm, cmap="Blues", aspect="auto")
        labels = [["TN", "FP"], ["FN", "TP"]]
        vmax = cm.max()
        for r in range(2):
            for c in range(2):
                ax_cm.text(c, r, f"{labels[r][c]}\n{int(cm[r, c])}",
                           ha="center", va="center", fontsize=cfg.font_size,
                           color=self._cell_text_color(cm[r, c], vmax))

        ax_cm.set_xticks([0, 1])
        ax_cm.set_xticklabels(["Pred Neg", "Pred Pos"],
                              fontsize=cfg.font_size - 1, color=fc)
        ax_cm.set_yticks([0, 1])
        ax_cm.set_yticklabels(["Actual Neg", "Actual Pos"],
                              fontsize=cfg.font_size - 1, color=fc)
        ax_cm.set_title("Confusion Matrix", color=fc,
                        fontsize=cfg.font_size, fontweight="bold")
        ax_cm.tick_params(colors=fc)

        # Metrics bar chart
        metrics = {
            "Sensitivity": result.sensitivity,
            "Specificity": result.specificity,
            "PPV": result.ppv,
            "NPV": result.npv,
            "Accuracy": result.accuracy,
            "Youden J": result.youden,
        }
        m_labels = list(metrics.keys())
        m_values = list(metrics.values())
        colors = [pal[i % len(pal)] for i in range(len(m_labels))]

        bars = ax_bar.barh(m_labels, m_values, color=colors,
                           edgecolor="none", height=0.6)
        for bar, val in zip(bars, m_values):
            # Label just past the bar end, clamped inside the axes; negative
            # bars are labelled just right of zero.
            ax_bar.text(
                min(max(val, 0.0) + 0.02, 0.98),
                bar.get_y() + bar.get_height() / 2,
                f"{val:.3f}", va="center",
                fontsize=cfg.font_size - 1, color=fc,
            )

        # Youden's J can be negative; keep such bars visible.
        x_min = min(0.0, min(m_values))
        ax_bar.set_xlim((x_min - 0.05) if x_min < 0 else 0, 1.15)
        ax_bar.set_xlabel("Value", color=fc, fontsize=cfg.font_size)
        ax_bar.set_title("Performance Metrics", color=fc,
                         fontsize=cfg.font_size, fontweight="bold")
        ax_bar.tick_params(colors=fc)
        ax_bar.invert_yaxis()  # first metric at the top
        _style_axes(ax_bar, cfg)
        ax_bar.yaxis.grid(False)

        if cfg.title:
            fig.suptitle(cfg.title, fontsize=cfg.font_size + 3,
                         fontweight="bold", color=fc)

        fig.tight_layout()
        return fig

    # contingency

    def plot_contingency(
        self,
        result: Any,
        config: Optional[PlotConfig] = None,
    ) -> Any:
        """
        2x2 contingency table: annotated heatmap plus a summary panel.

        *result* is either the table itself or an object exposing it as
        ``.table``. The table must provide cells ``a``-``d``,
        ``risk_exposed``, ``risk_unexposed``, ``risk_ratio()``,
        ``odds_ratio()`` and ``chi_square()``. Cell percentages show "n/a"
        for an empty table instead of dividing by zero.
        """
        import matplotlib.pyplot as plt
        import matplotlib.gridspec as gridspec

        cfg = self._resolve_config(config)
        self._check_animation(cfg)
        _apply_theme(cfg.theme)
        fc = _FONT_COLOR.get(cfg.theme, "#222222")
        bg = _BG.get(cfg.theme, "white")

        tbl = result.table if hasattr(result, "table") else result
        a, b, c, d = tbl.a, tbl.b, tbl.c, tbl.d
        total = a + b + c + d

        def share(count):
            # Fraction of all subjects; guarded against an empty table.
            return f"{count / total:.1%}" if total else "n/a"

        fig = plt.figure(
            figsize=(_px_to_in(cfg.width), _px_to_in(cfg.height)),
            facecolor=bg,
        )
        gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1.2], figure=fig,
                               wspace=0.35)
        ax_tbl = fig.add_subplot(gs[0])
        ax_text = fig.add_subplot(gs[1])

        # Heatmap: rows = (Non-cases, Cases), columns = (Unexposed, Exposed).
        # NOTE(review): this places b = unexposed cases and c = exposed
        # non-cases; the common epidemiological layout has b = exposed
        # non-cases and c = unexposed cases. Confirm against the table class.
        cells = np.array([[d, c], [b, a]], dtype=float)
        cell_labels = [
            [f"TN (d)\n{d}\n{share(d)}",
             f"Exp Non-cases (c)\n{c}\n{share(c)}"],
            [f"Unexp Cases (b)\n{b}\n{share(b)}",
             f"Exp Cases (a)\n{a}\n{share(a)}"],
        ]
        ax_tbl.imshow(cells, cmap="Blues", aspect="auto")
        vmax = cells.max()
        for r in range(2):
            for col in range(2):
                ax_tbl.text(col, r, cell_labels[r][col],
                            ha="center", va="center",
                            fontsize=cfg.font_size - 2,
                            color=self._cell_text_color(cells[r, col], vmax))

        ax_tbl.set_xticks([0, 1])
        ax_tbl.set_xticklabels(["Unexposed", "Exposed"],
                               fontsize=cfg.font_size - 1, color=fc)
        ax_tbl.set_yticks([0, 1])
        ax_tbl.set_yticklabels(["Non-cases", "Cases"],
                               fontsize=cfg.font_size - 1, color=fc)
        ax_tbl.set_title("2×2 Table", color=fc,
                         fontsize=cfg.font_size + 1, fontweight="bold")
        ax_tbl.tick_params(colors=fc)

        # Summary text panel; CI bounds are separated by ", " to match
        # plot_association.
        rr = tbl.risk_ratio()
        or_ = tbl.odds_ratio()
        chi = tbl.chi_square()
        summary_lines = [
            ("Risk (exposed)", f"{tbl.risk_exposed:.4f}"),
            ("Risk (unexposed)", f"{tbl.risk_unexposed:.4f}"),
            ("Risk Ratio",
             f"{rr.estimate:.3f} [{rr.ci_lower:.3f}, {rr.ci_upper:.3f}]"),
            ("Odds Ratio",
             f"{or_.estimate:.3f} [{or_.ci_lower:.3f}, {or_.ci_upper:.3f}]"),
            ("χ² p-value", f"{chi['p_value']:.4f}"),
            ("N total", f"{total}"),
        ]
        ax_text.axis("off")
        ax_text.set_facecolor(bg)
        y_start = 0.92
        line_h = 0.13
        for i, (label, value) in enumerate(summary_lines):
            y = y_start - i * line_h
            ax_text.text(0.02, y, label + ":", transform=ax_text.transAxes,
                         fontsize=cfg.font_size - 1, color="#888888",
                         va="top")
            ax_text.text(0.55, y, value, transform=ax_text.transAxes,
                         fontsize=cfg.font_size - 1, color=fc, va="top",
                         fontweight="bold")
        ax_text.set_title("Summary", color=fc, fontsize=cfg.font_size + 1,
                          fontweight="bold")

        if cfg.title:
            fig.suptitle(cfg.title, fontsize=cfg.font_size + 3,
                         fontweight="bold", color=fc, y=1.01)

        fig.tight_layout()
        return fig

    # save

    def save(
        self,
        fig: Any,
        path: str,
        fmt: OutputFormat = OutputFormat.PNG,
        dpi: int = 300,
    ) -> str:
        """
        Save a Matplotlib figure at publication quality.

        Default DPI is 300. Intended for static formats (PNG, SVG, PDF).
        The extension ``.<fmt.value>`` is appended unless *path* already ends
        with it (compared case-insensitively). The figure's own facecolor is
        preserved and the bounding box is trimmed to the content.

        Returns the absolute path written.

        Raises UnsupportedAnimationError for HTML, JSON, GIF and MP4.
        """
        import os

        if fmt in (OutputFormat.HTML, OutputFormat.JSON,
                   OutputFormat.GIF, OutputFormat.MP4):
            raise UnsupportedAnimationError(
                f"MatplotlibPlotter.save() does not support {fmt.value}. "
                "Use PlotlyPlotter for HTML/JSON output."
            )

        ext = f".{fmt.value}"
        if not path.lower().endswith(ext.lower()):
            path = path + ext
        path = os.path.abspath(path)
        fig.savefig(path, dpi=dpi, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
        return path
# Exports
__all__ = ["MatplotlibPlotter"]