"""
viz/plotters/plotly_plotter.py - Plotly rendering backend for Episia.
Default backend produces interactive HTML figures suitable for:
- Notebooks (Jupyter / JupyterLab)
- Web frontends (React via plotly.js / JSON serialization)
- Standalone HTML exports
Supported animations
FRAME_BY_FRAME plot_epicurve, plot_forest, plot_diagnostic
CONTINUOUS plot_model, plot_roc
PLAY_PAUSE plot_epicurve, plot_model
SLIDER plot_model (parameter sweep)
All plot methods accept an optional PlotConfig. If omitted, the
instance default_config is used.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from .base_plotter import (
AnimationConfig,
AnimationType,
BasePlotter,
OutputFormat,
PlotConfig,
UnsupportedAnimationError,
)
# Colour palettes per theme
_PALETTES: Dict[str, List[str]] = {
"scientific": ["#1f77b4", "#d62728", "#2ca02c", "#ff7f0e", "#9467bd",
"#8c564b", "#e377c2", "#7f7f7f"],
"minimal": ["#333333", "#888888", "#bbbbbb", "#555555", "#aaaaaa"],
"dark": ["#64b5f6", "#ef5350", "#66bb6a", "#ffa726", "#ab47bc",
"#26c6da", "#d4e157", "#ff7043"],
"colorblind": ["#0072B2", "#E69F00", "#56B4E9", "#009E73", "#F0E442",
"#D55E00", "#CC79A7", "#999999"],
}
_BG: Dict[str, str] = {
"scientific": "#ffffff",
"minimal": "#ffffff",
"dark": "#1e1e2e",
"colorblind": "#ffffff",
}
_GRID: Dict[str, str] = {
"scientific": "#e5e5e5",
"minimal": "#f0f0f0",
"dark": "#333355",
"colorblind": "#e5e5e5",
}
_FONT_COLOR: Dict[str, str] = {
"scientific": "#222222",
"minimal": "#333333",
"dark": "#eeeeee",
"colorblind": "#222222",
}
# Helper: build Plotly layout dict
def _layout(cfg: PlotConfig, **overrides) -> Dict:
theme = cfg.theme
bg = _BG.get(theme, "#ffffff")
grid_c = _GRID.get(theme, "#e5e5e5")
font_c = _FONT_COLOR.get(theme, "#222222")
title_text = cfg.title
if cfg.subtitle:
title_text += f"<br><sup>{cfg.subtitle}</sup>"
base = dict(
title=dict(text=title_text, font=dict(size=cfg.font_size + 3, color=font_c)),
xaxis=dict(
title=cfg.xlabel,
showgrid=cfg.show_grid,
gridcolor=grid_c,
zeroline=False,
color=font_c,
),
yaxis=dict(
title=cfg.ylabel,
showgrid=cfg.show_grid,
gridcolor=grid_c,
zeroline=False,
color=font_c,
),
plot_bgcolor=bg,
paper_bgcolor=bg,
font=dict(size=cfg.font_size, color=font_c),
legend=dict(visible=cfg.show_legend, bgcolor="rgba(0,0,0,0)"),
width=cfg.width,
height=cfg.height,
margin=dict(l=60, r=30, t=70, b=60),
)
base.update(overrides)
return base
def _palette(cfg: PlotConfig) -> List[str]:
if cfg.palette:
return cfg.palette
return _PALETTES.get(cfg.theme, _PALETTES["scientific"])
# Helper: build animation frames + layout buttons
def _make_play_pause_buttons(anim: AnimationConfig) -> List[Dict]:
"""Return updatemenus list with Play / Pause buttons."""
return [
dict(
type="buttons",
showactive=False,
y=0,
x=0.5,
xanchor="center",
yanchor="top",
buttons=[
dict(
label="Play",
method="animate",
args=[
None,
dict(
frame=dict(duration=anim.frame_ms, redraw=True),
transition=dict(duration=anim.transition_ms,
easing=anim.easing),
fromcurrent=True,
loop=anim.loop,
),
],
),
dict(
label="Pause",
method="animate",
args=[[None], dict(frame=dict(duration=0, redraw=False),
mode="immediate")],
),
],
)
]
def _make_slider(labels: List[str], anim: AnimationConfig,
prefix: str = "t=") -> Dict:
"""Return a Plotly slider dict over animation frames."""
steps = [
dict(
args=[[str(i)],
dict(frame=dict(duration=anim.frame_ms, redraw=True),
mode="immediate",
transition=dict(duration=anim.transition_ms))],
label=str(lab),
method="animate",
)
for i, lab in enumerate(labels)
]
return dict(
active=0,
currentvalue=dict(prefix=prefix, visible=True, xanchor="center"),
pad=dict(b=10, t=50),
steps=steps,
)
# PlotlyPlotter
[docs]
class PlotlyPlotter(BasePlotter):
"""
Plotly rendering backend.
Returns plotly.graph_objects.Figure objects.
Call .show() to display, .to_json() to serialize for React/JS.
Animations supported
--
FRAME_BY_FRAME : plot_epicurve, plot_forest, plot_diagnostic
CONTINUOUS : plot_model, plot_roc
PLAY_PAUSE : plot_epicurve, plot_model
SLIDER : plot_model
"""
BACKEND_NAME = "plotly"
SUPPORTED_ANIMATIONS = (
AnimationType.FRAME_BY_FRAME,
AnimationType.CONTINUOUS,
AnimationType.PLAY_PAUSE,
AnimationType.SLIDER,
)
# epicurve
[docs]
def plot_epicurve(
self,
result: Any,
config: Optional[PlotConfig] = None,
) -> Any:
"""
Epidemic curve bar chart of cases over time.
Animation (FRAME_BY_FRAME / PLAY_PAUSE):
Bars build up period by period from left to right.
"""
import plotly.graph_objects as go
cfg = self._resolve_config(config)
self._check_animation(cfg)
anim = cfg.animation
color = _palette(cfg)[0]
times = list(result.times)
values = list(result.values)
n = len(times)
if not anim.enabled:
traces = [
go.Bar(
x=times, y=values,
marker_color=color,
name="Cases",
)
]
if result.trend is not None:
traces.append(go.Scatter(
x=times, y=list(result.trend),
mode="lines",
line=dict(color=_palette(cfg)[1], width=2, dash="dash"),
name=result.trend_method or "Trend",
))
layout = _layout(
cfg,
xaxis_title=cfg.xlabel or "Period",
yaxis_title=cfg.ylabel or "Cases",
bargap=0.15,
)
return go.Figure(data=traces, layout=layout)
# animated version
# Frame i shows bars for periods 0..i
frames = []
for i in range(n):
frame_data = [
go.Bar(x=times[: i + 1], y=values[: i + 1], marker_color=color)
]
frames.append(go.Frame(data=frame_data, name=str(i)))
fig = go.Figure(
data=[go.Bar(x=times[:1], y=values[:1], marker_color=color)],
frames=frames,
)
layout = _layout(
cfg,
xaxis=dict(
title=cfg.xlabel or "Period",
range=[-0.5, n - 0.5],
showgrid=cfg.show_grid,
),
yaxis=dict(
title=cfg.ylabel or "Cases",
range=[0, max(values) * 1.1],
showgrid=cfg.show_grid,
),
updatemenus=_make_play_pause_buttons(anim) if cfg.show_legend else [],
sliders=[_make_slider([str(t) for t in times], anim, prefix="Period: ")],
)
fig.update_layout(layout)
return fig
# model
[docs]
def plot_model(
self,
result: Any,
config: Optional[PlotConfig] = None,
) -> Any:
"""
Compartmental model trajectories (SIR / SEIR / SEIRD…).
Animation (CONTINUOUS / PLAY_PAUSE / SLIDER):
Lines draw from t=0 forward, one frame per time step.
SLIDER adds an interactive time scrubber.
"""
import plotly.graph_objects as go
cfg = self._resolve_config(config)
self._check_animation(cfg)
anim = cfg.animation
pal = _palette(cfg)
t = list(result.t)
compartments = result.compartments # dict name -> array
names = list(compartments.keys())
n_steps = len(t)
if not anim.enabled:
traces = []
for i, (name, arr) in enumerate(compartments.items()):
traces.append(go.Scatter(
x=t, y=list(arr),
mode="lines",
name=name,
line=dict(color=pal[i % len(pal)], width=2.5),
))
annotations = []
if result.r0 is not None:
annotations.append(dict(
x=0.02, y=0.97, xref="paper", yref="paper",
text=f"R₀ = {result.r0:.2f}",
showarrow=False,
font=dict(size=cfg.font_size, color=_FONT_COLOR.get(cfg.theme)),
bgcolor="rgba(255,255,255,0.7)",
bordercolor="#cccccc",
borderwidth=1,
))
layout = _layout(
cfg,
title=cfg.title or f"{result.model_type} Model",
xaxis_title=cfg.xlabel or "Time",
yaxis_title=cfg.ylabel or "Population",
annotations=annotations,
)
return go.Figure(data=traces, layout=layout)
# animated: one frame per time step -
frames = []
for i in range(1, n_steps + 1):
frame_traces = []
for j, (name, arr) in enumerate(compartments.items()):
frame_traces.append(go.Scatter(
x=t[:i], y=list(arr[:i]),
mode="lines",
name=name,
line=dict(color=pal[j % len(pal)], width=2.5),
))
frames.append(go.Frame(data=frame_traces, name=str(i - 1)))
# Initial frame: first point only
init_traces = [
go.Scatter(
x=t[:1], y=list(arr[:1]),
mode="lines", name=name,
line=dict(color=pal[j % len(pal)], width=2.5),
)
for j, (name, arr) in enumerate(compartments.items())
]
fig = go.Figure(data=init_traces, frames=frames)
extra: Dict = {}
if anim.anim_type == AnimationType.SLIDER:
extra["sliders"] = [_make_slider(
[f"{v:.1f}" for v in t], anim, prefix="t = "
)]
layout = _layout(
cfg,
title=cfg.title or f"{result.model_type} Model",
xaxis=dict(
title=cfg.xlabel or "Time",
range=[t[0], t[-1]],
showgrid=cfg.show_grid,
),
yaxis=dict(
title=cfg.ylabel or "Population",
showgrid=cfg.show_grid,
),
updatemenus=_make_play_pause_buttons(anim),
**extra,
)
fig.update_layout(layout)
return fig
# ROC
[docs]
def plot_roc(
self,
result: Any,
config: Optional[PlotConfig] = None,
) -> Any:
"""
ROC curve with AUC annotation and optimal threshold marker.
Animation (CONTINUOUS):
The curve traces itself from (0,0) to (1,1) point by point,
simulating a threshold sweep from high to low.
"""
import plotly.graph_objects as go
cfg = self._resolve_config(config)
self._check_animation(cfg)
anim = cfg.animation
pal = _palette(cfg)
fpr = list(result.fpr)
tpr = list(result.tpr)
n = len(fpr)
# Reference diagonal
diag = go.Scatter(
x=[0, 1], y=[0, 1],
mode="lines",
line=dict(color="#aaaaaa", width=1, dash="dot"),
showlegend=False,
hoverinfo="skip",
)
# Optimal threshold marker
opt_fpr = 1 - result.optimal_point.get("specificity", 0)
opt_tpr = result.optimal_point.get("sensitivity", 0)
marker = go.Scatter(
x=[opt_fpr], y=[opt_tpr],
mode="markers+text",
marker=dict(color=pal[1], size=10, symbol="star"),
text=[f" threshold={result.optimal_threshold:.3f}"],
textposition="middle right",
name="Optimal",
textfont=dict(color=_FONT_COLOR.get(cfg.theme)),
)
auc_annotation = dict(
x=0.97, y=0.05, xref="paper", yref="paper",
text=f"AUC = {result.auc:.3f}",
showarrow=False,
font=dict(size=cfg.font_size + 1, color=_FONT_COLOR.get(cfg.theme)),
bgcolor="rgba(255,255,255,0.8)",
bordercolor="#cccccc",
borderwidth=1,
align="right",
)
if not anim.enabled:
roc_trace = go.Scatter(
x=fpr, y=tpr,
mode="lines",
name=f"ROC (AUC={result.auc:.3f})",
line=dict(color=pal[0], width=2.5),
fill="tozeroy",
fillcolor=f"rgba({int(pal[0][1:3],16)},"
f"{int(pal[0][3:5],16)},"
f"{int(pal[0][5:7],16)},0.08)",
)
layout = _layout(
cfg,
title=cfg.title or "ROC Curve",
xaxis_title=cfg.xlabel or "False Positive Rate (1 − Specificity)",
yaxis_title=cfg.ylabel or "True Positive Rate (Sensitivity)",
xaxis=dict(range=[0, 1], constrain="domain"),
yaxis=dict(range=[0, 1.02], scaleanchor="x", scaleratio=1),
annotations=[auc_annotation],
)
return go.Figure(data=[diag, roc_trace, marker], layout=layout)
# animated: threshold sweep
frames = []
for i in range(2, n + 1):
frames.append(go.Frame(
data=[
diag,
go.Scatter(
x=fpr[:i], y=tpr[:i],
mode="lines",
line=dict(color=pal[0], width=2.5),
fill="tozeroy",
fillcolor=f"rgba({int(pal[0][1:3],16)},"
f"{int(pal[0][3:5],16)},"
f"{int(pal[0][5:7],16)},0.08)",
showlegend=False,
),
marker,
],
name=str(i - 2),
))
fig = go.Figure(
data=[diag, go.Scatter(x=fpr[:2], y=tpr[:2], mode="lines",
line=dict(color=pal[0], width=2.5)), marker],
frames=frames,
)
layout = _layout(
cfg,
title=cfg.title or "ROC Curve",
xaxis=dict(title=cfg.xlabel or "False Positive Rate",
range=[0, 1], showgrid=cfg.show_grid),
yaxis=dict(title=cfg.ylabel or "True Positive Rate",
range=[0, 1.02], showgrid=cfg.show_grid),
updatemenus=_make_play_pause_buttons(anim),
annotations=[auc_annotation],
)
fig.update_layout(layout)
return fig
# forest
[docs]
def plot_forest(
self,
result: Any,
config: Optional[PlotConfig] = None,
) -> Any:
"""
Forest plot for stratified (MH) or regression results.
Animation (FRAME_BY_FRAME):
Rows appear one by one from top to bottom.
"""
import plotly.graph_objects as go
cfg = self._resolve_config(config)
self._check_animation(cfg)
anim = cfg.animation
pal = _palette(cfg)
# Collect rows -
rows: List[Dict] = []
# StratifiedResult
if hasattr(result, "stratum_results") and result.stratum_results:
for s in result.stratum_results:
rows.append(dict(
label=s.metadata.get("label", s.measure),
est=s.estimate,
lo=s.ci.lower,
hi=s.ci.upper,
p=s.p_value,
))
# Pooled row
rows.append(dict(
label="Pooled (MH)",
est=result.mh_estimate,
lo=result.ci.lower,
hi=result.ci.upper,
p=result.p_value,
pooled=True,
))
# RegressionResult
elif hasattr(result, "coefficients"):
for var, coef in result.coefficients.items():
lo, hi = result.ci_table.get(var, (coef, coef))
rows.append(dict(
label=var,
est=coef,
lo=lo,
hi=hi,
p=result.p_values.get(var),
))
# Single AssociationResult fallback
else:
rows.append(dict(
label=getattr(result, "measure", "estimate"),
est=result.estimate,
lo=result.ci.lower,
hi=result.ci.upper,
p=result.p_value,
))
n_rows = len(rows)
y_pos = list(range(n_rows - 1, -1, -1)) # top to bottom
def _build_traces(subset_rows, subset_y):
traces = []
for row, y in zip(subset_rows, subset_y):
is_pooled = row.get("pooled", False)
color = pal[1] if is_pooled else pal[0]
size = 14 if is_pooled else 10
symbol = "diamond" if is_pooled else "square"
# CI line
traces.append(go.Scatter(
x=[row["lo"], row["hi"]], y=[y, y],
mode="lines",
line=dict(color=color, width=2),
showlegend=False,
hoverinfo="skip",
))
# Point estimate
if row.get("p") is not None:
_pv = row["p"]
_pf = "<0.001" if _pv < 0.001 else f"{_pv:.3f}"
p_str = f"p={_pf}"
else:
p_str = ""
traces.append(go.Scatter(
x=[row["est"]], y=[y],
mode="markers",
marker=dict(color=color, size=size, symbol=symbol),
name=row["label"],
hovertemplate=(
f"<b>{row['label']}</b><br>"
f"Estimate: {row['est']:.3f}<br>"
f"95% CI: [{row['lo']:.3f}, {row['hi']:.3f}]<br>"
f"{p_str}<extra></extra>"
),
showlegend=False,
))
return traces
# null line (1 for ratios, 0 for differences infer from estimates)
null_value = 1.0 if all(r["est"] > 0.01 for r in rows) else 0.0
base_layout = _layout(
cfg,
title=cfg.title or "Forest Plot",
xaxis_title=cfg.xlabel or "Estimate",
yaxis=dict(
tickvals=y_pos,
ticktext=[r["label"] for r in rows],
showgrid=False,
zeroline=False,
color=_FONT_COLOR.get(cfg.theme),
),
shapes=[dict(
type="line",
x0=null_value, x1=null_value,
y0=-0.5, y1=n_rows - 0.5,
line=dict(color="#999999", width=1, dash="dot"),
)],
)
if not anim.enabled:
traces = _build_traces(rows, y_pos)
return go.Figure(data=traces, layout=base_layout)
# animated: rows appear one by one --
frames = []
for i in range(1, n_rows + 1):
frames.append(go.Frame(
data=_build_traces(rows[:i], y_pos[:i]),
name=str(i - 1),
))
fig = go.Figure(
data=_build_traces(rows[:1], y_pos[:1]),
frames=frames,
)
base_layout["updatemenus"] = _make_play_pause_buttons(anim)
fig.update_layout(base_layout)
return fig
# association
[docs]
def plot_association(
self,
result: Any,
config: Optional[PlotConfig] = None,
) -> Any:
"""
Single association measure horizontal CI plot with reference line.
Static (no animation by design).
"""
import plotly.graph_objects as go
cfg = self._resolve_config(config)
pal = _palette(cfg)
label = result.measure.replace("_", " ").title()
null_value = result.null_value
fig = go.Figure()
# CI bar
fig.add_trace(go.Scatter(
x=[result.ci.lower, result.ci.upper],
y=[label, label],
mode="lines",
line=dict(color=pal[0], width=4),
showlegend=False,
hoverinfo="skip",
))
# Point estimate
p_str = ""
if result.p_value is not None:
p_str = f"p={'<0.001' if result.p_value < 0.001 else f'{result.p_value:.3f}'}"
fig.add_trace(go.Scatter(
x=[result.estimate],
y=[label],
mode="markers",
marker=dict(color=pal[1], size=14, symbol="diamond"),
hovertemplate=(
f"<b>{label}</b><br>"
f"Estimate: {result.estimate:.3f}<br>"
f"{int(result.ci.confidence * 100)}% CI: "
f"[{result.ci.lower:.3f}, {result.ci.upper:.3f}]<br>"
f"{p_str}<extra></extra>"
),
showlegend=False,
))
sig_color = pal[2] if result.significant else "#aaaaaa"
significance = "Significant" if result.significant else "Not significant"
layout = _layout(
cfg,
title=cfg.title or label,
xaxis_title=cfg.xlabel or "Estimate",
yaxis=dict(showgrid=False, zeroline=False,
color=_FONT_COLOR.get(cfg.theme)),
height=min(cfg.height, 250),
shapes=[dict(
type="line",
x0=null_value, x1=null_value,
y0=-0.5, y1=0.5,
line=dict(color="#999999", width=1, dash="dot"),
)],
annotations=[dict(
x=0.99, y=0.95, xref="paper", yref="paper",
text=significance,
showarrow=False,
font=dict(color=sig_color, size=cfg.font_size),
align="right",
)],
)
fig.update_layout(layout)
return fig
# diagnostic
[docs]
def plot_diagnostic(
self,
result: Any,
config: Optional[PlotConfig] = None,
) -> Any:
"""
Diagnostic test dashboard: confusion matrix heatmap + metrics bars.
Animation (FRAME_BY_FRAME):
Metric bars fill in one by one.
"""
import plotly.graph_objects as go
from plotly.subplots import make_subplots
cfg = self._resolve_config(config)
self._check_animation(cfg)
anim = cfg.animation
pal = _palette(cfg)
font_c = _FONT_COLOR.get(cfg.theme, "#222222")
bg = _BG.get(cfg.theme, "#ffffff")
# Confusion matrix
cm_values = [
[result.tn, result.fp],
[result.fn, result.tp],
]
cm_text = [
[f"TN<br>{result.tn}", f"FP<br>{result.fp}"],
[f"FN<br>{result.fn}", f"TP<br>{result.tp}"],
]
cm_trace = go.Heatmap(
z=cm_values,
text=cm_text,
texttemplate="%{text}",
x=["Predicted Neg", "Predicted Pos"],
y=["Actual Neg", "Actual Pos"],
colorscale=[[0, "#d6e8ff"], [1, pal[0]]],
showscale=False,
hoverinfo="text",
)
# Metrics bar chart --
metrics = {
"Sensitivity": result.sensitivity,
"Specificity": result.specificity,
"PPV": result.ppv,
"NPV": result.npv,
"Accuracy": result.accuracy,
"Youden J": result.youden,
}
m_labels = list(metrics.keys())
m_values = list(metrics.values())
colors = [pal[i % len(pal)] for i in range(len(m_labels))]
# Subplot layout
fig = make_subplots(
rows=1, cols=2,
column_widths=[0.4, 0.6],
subplot_titles=["Confusion Matrix", "Performance Metrics"],
)
fig.add_trace(cm_trace, row=1, col=1)
if not anim.enabled:
fig.add_trace(go.Bar(
x=m_labels, y=m_values,
marker_color=colors,
text=[f"{v:.3f}" for v in m_values],
textposition="outside",
showlegend=False,
), row=1, col=2)
else:
# Initial: first metric only
fig.add_trace(go.Bar(
x=m_labels[:1], y=m_values[:1],
marker_color=colors[:1],
text=[f"{m_values[0]:.3f}"],
textposition="outside",
showlegend=False,
), row=1, col=2)
# Frames: add one metric per frame
frames = []
for i in range(1, len(m_labels) + 1):
frames.append(go.Frame(
data=[
cm_trace,
go.Bar(
x=m_labels[:i], y=m_values[:i],
marker_color=colors[:i],
text=[f"{v:.3f}" for v in m_values[:i]],
textposition="outside",
),
],
name=str(i - 1),
))
fig.frames = frames
fig.update_layout(
updatemenus=_make_play_pause_buttons(anim)
)
fig.update_layout(
title=cfg.title or "Diagnostic Test Performance",
plot_bgcolor=bg,
paper_bgcolor=bg,
font=dict(size=cfg.font_size, color=font_c),
width=cfg.width,
height=cfg.height,
yaxis2=dict(range=[0, 1.15], showgrid=cfg.show_grid,
gridcolor=_GRID.get(cfg.theme, "#e5e5e5")),
)
return fig
# contingency
[docs]
def plot_contingency(
self,
result: Any,
config: Optional[PlotConfig] = None,
) -> Any:
"""
2x2 contingency table annotated heatmap with risk summary.
Static (no animation by design).
"""
import plotly.graph_objects as go
from plotly.subplots import make_subplots
cfg = self._resolve_config(config)
pal = _palette(cfg)
font_c = _FONT_COLOR.get(cfg.theme, "#222222")
bg = _BG.get(cfg.theme, "#ffffff")
# Extract table cells
if hasattr(result, "table"):
tbl = result.table
else:
tbl = result
a, b, c, d = tbl.a, tbl.b, tbl.c, tbl.d
total = a + b + c + d
cells = [[d, c], [b, a]]
texts = [
[f"TN (d)<br><b>{d}</b><br>{d/total:.1%}",
f"Exposed Non-cases (c)<br><b>{c}</b><br>{c/total:.1%}"],
[f"Unexposed Cases (b)<br><b>{b}</b><br>{b/total:.1%}",
f"Exposed Cases (a)<br><b>{a}</b><br>{a/total:.1%}"],
]
fig = make_subplots(
rows=1, cols=2,
column_widths=[0.55, 0.45],
subplot_titles=["2×2 Table", "Summary"],
)
fig.add_trace(go.Heatmap(
z=cells,
text=texts,
texttemplate="%{text}",
x=["Unexposed", "Exposed"],
y=["Non-cases", "Cases"],
colorscale=[[0, "#f0f7ff"], [1, pal[0]]],
showscale=False,
hoverinfo="text",
), row=1, col=1)
# Summary metrics
rr_result = tbl.risk_ratio()
or_result = tbl.odds_ratio()
chi2 = tbl.chi_square()
summary_labels = ["Risk Exposed", "Risk Unexposed", "Risk Ratio",
"Odds Ratio", "χ² p-value"]
summary_values = [
f"{tbl.risk_exposed:.3f}",
f"{tbl.risk_unexposed:.3f}",
f"{rr_result.estimate:.3f} ({rr_result.ci_lower:.3f}–{rr_result.ci_upper:.3f})",
f"{or_result.estimate:.3f} ({or_result.ci_lower:.3f}–{or_result.ci_upper:.3f})",
f"{chi2['p_value']:.4f}",
]
fig.add_trace(go.Table(
header=dict(
values=["<b>Measure</b>", "<b>Value</b>"],
fill_color=pal[0],
font=dict(color="white", size=cfg.font_size),
align="left",
),
cells=dict(
values=[summary_labels, summary_values],
fill_color=[[bg, bg] * 3],
font=dict(color=font_c, size=cfg.font_size - 1),
align="left",
height=28,
),
), row=1, col=2)
fig.update_layout(
title=cfg.title or "2×2 Contingency Table",
plot_bgcolor=bg,
paper_bgcolor=bg,
font=dict(size=cfg.font_size, color=font_c),
width=cfg.width,
height=cfg.height,
)
return fig
# save
[docs]
def save(
self,
fig: Any,
path: str,
fmt: OutputFormat = OutputFormat.PNG,
dpi: int = 150,
) -> str:
"""
Save a Plotly figure to disk.
Supports: PNG, SVG, PDF (via kaleido), HTML, JSON.
"""
import os
ext = f".{fmt.value}"
if not path.endswith(ext):
path = path + ext
path = os.path.abspath(path)
if fmt == OutputFormat.HTML:
fig.write_html(path)
elif fmt == OutputFormat.JSON:
with open(path, "w") as f:
f.write(fig.to_json())
else:
# PNG / SVG / PDF require kaleido
try:
fig.write_image(path, scale=dpi / 72)
except Exception as e:
raise RuntimeError(
f"Could not save as {fmt.value}. "
f"Install kaleido: pip install kaleido\n{e}"
)
return path
# Exports
__all__ = ["PlotlyPlotter"]