Source code for episia.viz.contingency_plot

"""
viz/contingency_plot.py - 2x2 table visualizations for Episia.

Public functions
----------------
    plot_contingency   annotated heatmap + summary metrics
    plot_measures      horizontal CI chart for all association measures
"""

from __future__ import annotations

from typing import Any, List, Optional

import numpy as np

from .plotters import get_plotter, PlotConfig
from .themes.registry import get_palette



# plot_contingency

def plot_contingency(
    result: Any,
    *,
    title: str = "2×2 Contingency Table",
    backend: str = "plotly",
    theme: str = "scientific",
    config: Optional[PlotConfig] = None,
) -> Any:
    """
    Draw a 2×2 contingency table as an annotated heatmap with an
    RR / OR / χ² summary, delegating the rendering to a backend plotter.

    Args:
        result:  Table2x2 instance, or AssociationResult with table metadata.
        title:   Figure title; only used when ``config`` is not supplied.
        backend: 'plotly' or 'matplotlib'.
        theme:   Theme name; only used when ``config`` is not supplied.
        config:  Full PlotConfig override. When given it is passed to the
                 plotter unchanged and ``title`` / ``theme`` are ignored.

    Returns:
        Figure object produced by the selected backend plotter.

    Example::

        from episia.stats.contingency import Table2x2
        from episia.viz.contingency_plot import plot_contingency

        tbl = Table2x2(40, 10, 20, 30)
        plot_contingency(tbl, title="Exposure A vs Disease B").show()
    """
    # Build a default configuration only when the caller did not bring one.
    effective_config = (
        PlotConfig(title=title, theme=theme) if config is None else config
    )
    plotter = get_plotter(backend)
    return plotter.plot_contingency(result, config=effective_config)
# plot_measures


def _association_rows(tbl: Any, measures: Optional[List[str]]) -> List[dict]:
    """Build one plotting record (label, est, lo, hi, null) per measure of ``tbl``."""
    rr = tbl.risk_ratio()
    or_ = tbl.odds_ratio()
    rd = tbl.risk_difference()

    rows = [
        dict(label="Risk Ratio", est=rr.estimate,
             lo=rr.ci_lower, hi=rr.ci_upper, null=1.0),
        dict(label="Odds Ratio", est=or_.estimate,
             lo=or_.ci_lower, hi=or_.ci_upper, null=1.0),
        dict(label="Risk Difference", est=rd["estimate"],
             lo=rd["ci_lower"], hi=rd["ci_upper"], null=0.0),
    ]

    if measures:
        # Keep the caller's ordering; labels that match nothing are dropped.
        by_label = {r["label"]: r for r in rows}
        rows = [by_label[m] for m in measures if m in by_label]
    return rows


def _measures_plotly(
    rows: List[dict],
    y_pos: List[int],
    pal: Any,
    theme: str,
    config: Any,
) -> Any:
    """Render the measure rows as a plotly figure (one CI bar + diamond per row)."""
    import plotly.graph_objects as go
    from .plotters.plotly_plotter import _layout, _FONT_COLOR

    fc = _FONT_COLOR.get(theme, "#222222")
    fig = go.Figure()

    for row, y in zip(rows, y_pos):
        # Null line per measure (1.0 for ratios, 0.0 for the difference).
        fig.add_shape(
            type="line",
            x0=row["null"], x1=row["null"],
            y0=y - 0.4, y1=y + 0.4,
            line=dict(color="#cccccc", width=1, dash="dot"),
        )
        # CI bar
        fig.add_trace(go.Scatter(
            x=[row["lo"], row["hi"]], y=[y, y],
            mode="lines",
            line=dict(color=pal[0], width=4),
            showlegend=False,
            hoverinfo="skip",
        ))
        # Point estimate
        fig.add_trace(go.Scatter(
            x=[row["est"]], y=[y],
            mode="markers",
            marker=dict(color=pal[1], size=12, symbol="diamond"),
            name=row["label"],
            hovertemplate=(
                f"<b>{row['label']}</b><br>"
                f"Estimate: {row['est']:.3f}<br>"
                f"95% CI: [{row['lo']:.3f}, {row['hi']:.3f}]"
                "<extra></extra>"
            ),
            showlegend=False,
        ))

    fig.update_layout(_layout(
        config,
        yaxis=dict(
            tickvals=y_pos,
            ticktext=[r["label"] for r in rows],
            showgrid=False,
            zeroline=False,
            color=fc,
        ),
        xaxis_title="Estimate",
    ))
    return fig


def _measures_matplotlib(
    rows: List[dict],
    y_pos: List[int],
    pal: Any,
    theme: str,
    config: Any,
    title: str,
) -> Any:
    """Render the measure rows as a matplotlib figure (one CI bar + diamond per row)."""
    import matplotlib.pyplot as plt
    from .themes.registry import apply_mpl_theme

    apply_mpl_theme(theme)
    fig, ax = plt.subplots(
        figsize=(config.width / 100, config.height / 100),
        facecolor="white",
    )

    for row, y in zip(rows, y_pos):
        # Null marker drawn in DATA coordinates. ``axvline``'s ymin/ymax are
        # axes fractions (0-1), so dividing the row position by the row
        # count did not line the marker up with its row.
        ax.plot([row["null"], row["null"]], [y - 0.4, y + 0.4],
                color="#cccccc", linewidth=0.8, linestyle="--", zorder=1)
        ax.plot([row["lo"], row["hi"]], [y, y],
                color=pal[0], linewidth=4, solid_capstyle="round")
        ax.plot(row["est"], y, "D", color=pal[1], markersize=10, zorder=5)
        ax.text(row["hi"] + 0.02, y,
                f" {row['est']:.3f} [{row['lo']:.3f}, {row['hi']:.3f}]",
                va="center", fontsize=config.font_size - 2)

    ax.set_yticks(y_pos)
    ax.set_yticklabels([r["label"] for r in rows],
                       fontsize=config.font_size - 1)
    ax.set_xlabel("Estimate", fontsize=config.font_size)
    # ``config`` is the documented full override, so its title wins; fall
    # back to ``title`` if the config carries none.
    heading = getattr(config, "title", None) or title
    ax.set_title(heading, fontsize=config.font_size + 2, fontweight="bold")
    ax.yaxis.grid(False)
    fig.tight_layout()
    return fig


def plot_measures(
    result: Any,
    *,
    measures: Optional[List[str]] = None,
    title: str = "Association Measures",
    backend: str = "plotly",
    theme: str = "scientific",
    config: Optional[PlotConfig] = None,
) -> Any:
    """
    Horizontal CI chart for all association measures from a Table2x2.

    Displays RR, OR, and RD one row each, with a confidence-interval bar,
    a diamond at the point estimate and a short marker at the measure's
    null value (1.0 for the ratios, 0.0 for the difference).

    Args:
        result:   Table2x2 or AssociationResult. If the object has a
                  ``table`` attribute, that attribute is used as the table.
        measures: Subset of measures to display, given by row label
                  ("Risk Ratio", "Odds Ratio", "Risk Difference"), in the
                  order they should appear. Unknown labels are ignored.
                  Default (None or empty): all three.
        title:    Figure title; used when ``config`` is not supplied.
        backend:  'plotly' selects the plotly renderer; any other value
                  falls through to the matplotlib renderer.
        theme:    Theme name, used for the palette and the default config.
        config:   Full PlotConfig override. When omitted, a config with
                  ``height=300`` is built from ``title`` and ``theme``.

    Returns:
        Figure object (plotly ``Figure`` or matplotlib ``Figure``).
    """
    # Extract Table2x2
    tbl = result.table if hasattr(result, "table") else result
    rows = _association_rows(tbl, measures)

    if config is None:
        config = PlotConfig(title=title, theme=theme, height=300)

    pal = get_palette(theme)
    # First row is placed at the top of the chart, hence descending y.
    y_pos = list(range(len(rows) - 1, -1, -1))

    if backend == "plotly":
        return _measures_plotly(rows, y_pos, pal, theme, config)
    return _measures_matplotlib(rows, y_pos, pal, theme, config, title)
# Public API of this module: the two plotting entry points defined above.
__all__ = ["plot_contingency", "plot_measures"]