Source code for episia.viz.forest
"""
viz/forest.py - Forest plot visualizations for Episia.
Public functions
----------------
plot_forest stratified / regression forest plot
plot_meta_forest meta-analysis style with heterogeneity stats
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from .plotters import get_plotter, PlotConfig, AnimationConfig, AnimationType
from .themes.registry import get_palette
# _collect_rows helper
def _collect_rows(result: Any) -> Tuple[List[Dict], float]:
"""Extract rows and null_value from various result types."""
rows: List[Dict] = []
null_value = 1.0
# StratifiedResult (api.results)
if hasattr(result, "stratum_results") and result.stratum_results:
for s in result.stratum_results:
rows.append(dict(
label=s.metadata.get("label", s.measure),
est=s.estimate,
lo=s.ci.lower,
hi=s.ci.upper,
p=s.p_value,
n=s.n_total,
pooled=False,
))
rows.append(dict(
label="Pooled (MH)",
est=result.mh_estimate,
lo=result.ci.lower,
hi=result.ci.upper,
p=result.p_value,
n=None,
pooled=True,
))
# RegressionResult (api.results)
elif hasattr(result, "coefficients"):
null_value = 0.0
for var, coef in result.coefficients.items():
lo, hi = result.ci_table.get(var, (coef, coef))
rows.append(dict(
label=var,
est=coef,
lo=lo,
hi=hi,
p=result.p_values.get(var),
n=None,
pooled=False,
))
# Bare AssociationResult or single-row fallback
else:
rows.append(dict(
label=getattr(result, "measure", "estimate"),
est=result.estimate,
lo=result.ci.lower,
hi=result.ci.upper,
p=getattr(result, "p_value", None),
n=getattr(result, "n_total", None),
pooled=False,
))
null_value = getattr(result, "null_value", 1.0)
return rows, null_value
def _p_str(p: Optional[float]) -> str:
if p is None:
return ""
return "p<0.001" if p < 0.001 else f"p={p:.3f}"
# plot_forest
[docs]
def plot_forest(
result: Any,
*,
title: str = "Forest Plot",
xlabel: str = "Estimate",
animate: bool = False,
backend: str = "plotly",
theme: str = "scientific",
config: Optional[PlotConfig] = None,
) -> Any:
"""
Forest plot for stratified or regression results.
Args:
result: StratifiedResult, RegressionResult, or AssociationResult.
title: Figure title.
xlabel: X-axis label.
animate: Rows appear one by one (Plotly only).
backend: 'plotly' or 'matplotlib'.
theme: Theme name.
config: Full PlotConfig override.
Returns:
Figure object.
Example::
from episia.viz.forest import plot_forest
plot_forest(stratified_result, title="Stratified OR by age group").show()
"""
if config is None:
anim = (
AnimationConfig(enabled=True, anim_type=AnimationType.FRAME_BY_FRAME,
frame_ms=300)
if animate else AnimationConfig.default()
)
config = PlotConfig(
title=title, xlabel=xlabel, theme=theme, animation=anim,
)
return get_plotter(backend).plot_forest(result, config=config)
# plot_meta_forest
[docs]
def plot_meta_forest(
estimates: List[float],
ci_lowers: List[float],
ci_uppers: List[float],
labels: List[str],
*,
weights: Optional[List[float]] = None,
pooled_estimate: Optional[float] = None,
pooled_ci: Optional[Tuple[float, float]] = None,
i_squared: Optional[float] = None,
tau_squared: Optional[float] = None,
p_heterogeneity: Optional[float] = None,
null_value: float = 1.0,
title: str = "Meta-Analysis Forest Plot",
xlabel: str = "Effect Estimate",
backend: str = "plotly",
theme: str = "scientific",
config: Optional[PlotConfig] = None,
) -> Any:
"""
Meta-analysis style forest plot with heterogeneity statistics.
Marker sizes are proportional to study weights.
I² and τ² annotations are included when provided.
Args:
estimates: Per-study point estimates.
ci_lowers: Per-study lower CI bounds.
ci_uppers: Per-study upper CI bounds.
labels: Study / stratum labels.
weights: Relative weights (e.g. 1/variance). Auto-normalised.
pooled_estimate: Pooled (diamond) point estimate.
pooled_ci: (lower, upper) for pooled estimate.
i_squared: I² heterogeneity statistic (%).
tau_squared: τ² between-study variance.
p_heterogeneity: P-value for Q heterogeneity test.
null_value: Null reference line position (1.0 for ratios).
title / xlabel: Labels.
backend: 'plotly' or 'matplotlib'.
theme: Theme name.
config: Full PlotConfig override.
Returns:
Figure object.
"""
n = len(estimates)
if not (len(ci_lowers) == len(ci_uppers) == len(labels) == n):
raise ValueError("estimates, ci_lowers, ci_uppers, labels must all have same length.")
# Normalise weights → marker sizes
if weights is not None:
w = np.asarray(weights, dtype=float)
w_norm = (w / w.max()) * 12 + 4 # sizes 4–16
else:
w_norm = np.full(n, 8.0)
if config is None:
config = PlotConfig(title=title, xlabel=xlabel, theme=theme,
height=max(400, n * 32 + 150))
pal = get_palette(theme)
# Heterogeneity annotation text
het_lines = []
if i_squared is not None:
het_lines.append(f"I² = {i_squared:.1f}%")
if tau_squared is not None:
het_lines.append(f"τ² = {tau_squared:.4f}")
if p_heterogeneity is not None:
het_lines.append(_p_str(p_heterogeneity) + " (heterogeneity)")
het_text = " | ".join(het_lines)
y_positions = list(range(n - 1, -1, -1))
if backend == "plotly":
import plotly.graph_objects as go
from .plotters.plotly_plotter import _layout, _FONT_COLOR
fc = _FONT_COLOR.get(theme, "#222222")
fig = go.Figure()
# Null line
fig.add_vline(x=null_value, line=dict(color="#999999", width=1,
dash="dot"))
for i in range(n):
y = y_positions[i]
# CI whisker
fig.add_trace(go.Scatter(
x=[ci_lowers[i], ci_uppers[i]], y=[y, y],
mode="lines",
line=dict(color=pal[0], width=1.8),
showlegend=False, hoverinfo="skip",
))
# Point estimate
fig.add_trace(go.Scatter(
x=[estimates[i]], y=[y],
mode="markers",
marker=dict(color=pal[0], size=float(w_norm[i]),
symbol="square"),
name=labels[i],
hovertemplate=(
f"<b>{labels[i]}</b><br>"
f"Estimate: {estimates[i]:.3f}<br>"
f"95% CI: [{ci_lowers[i]:.3f}, {ci_uppers[i]:.3f}]"
"<extra></extra>"
),
showlegend=False,
))
# Pooled diamond
if pooled_estimate is not None and pooled_ci is not None:
y_pool = -1
lo, hi = pooled_ci
mid_h = 0.35
diamond_x = [lo, pooled_estimate, hi, pooled_estimate, lo]
diamond_y = [y_pool, y_pool + mid_h, y_pool,
y_pool - mid_h, y_pool]
fig.add_trace(go.Scatter(
x=diamond_x, y=diamond_y,
mode="lines",
fill="toself",
fillcolor=pal[1],
line=dict(color=pal[1], width=1),
name=f"Pooled: {pooled_estimate:.3f} [{lo:.3f}, {hi:.3f}]",
))
annotations = []
if het_text:
annotations.append(dict(
x=0.5, y=-0.12, xref="paper", yref="paper",
text=het_text, showarrow=False,
font=dict(size=config.font_size - 1, color=fc),
align="center",
))
y_min = (-1.8 if pooled_estimate is not None else -0.5)
fig.update_layout(_layout(
config,
yaxis=dict(
tickvals=y_positions,
ticktext=labels,
range=[y_min, n - 0.3],
showgrid=False, zeroline=False, color=fc,
),
annotations=annotations,
))
return fig
else:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from .themes.registry import apply_mpl_theme
apply_mpl_theme(theme)
fig_h = max(4.0, n * 0.45 + 1.8)
fig, ax = plt.subplots(
figsize=(config.width / 100, fig_h), facecolor="white",
)
ax.axvline(null_value, color="#999999", linewidth=1,
linestyle="--", alpha=0.7)
for i in range(n):
y = y_positions[i]
ax.plot([ci_lowers[i], ci_uppers[i]], [y, y],
color=pal[0], linewidth=1.8)
ax.plot(estimates[i], y, "s",
color=pal[0], markersize=float(w_norm[i]) * 0.7,
zorder=5)
# Pooled diamond
if pooled_estimate is not None and pooled_ci is not None:
lo, hi = pooled_ci
y_pool = -1
diamond = mpatches.FancyArrow(
0, 0, 0, 0 # placeholder use polygon instead
)
poly_x = [lo, pooled_estimate, hi, pooled_estimate]
poly_y = [y_pool, y_pool + 0.35, y_pool, y_pool - 0.35]
ax.fill(poly_x, poly_y, color=pal[1], zorder=5)
ax.plot(poly_x + [poly_x[0]], poly_y + [poly_y[0]],
color=pal[1], linewidth=1)
ax.set_yticks(y_positions)
ax.set_yticklabels(labels, fontsize=config.font_size - 1)
ax.set_xlabel(xlabel, fontsize=config.font_size)
ax.set_title(title, fontsize=config.font_size + 2, fontweight="bold")
if pooled_estimate is not None:
ax.set_ylim(-1.8, n - 0.3)
else:
ax.set_ylim(-0.5, n - 0.3)
if het_text:
fig.text(0.5, -0.02, het_text,
ha="center", fontsize=config.font_size - 2,
style="italic", transform=ax.transAxes)
ax.yaxis.grid(False)
fig.tight_layout()
return fig
__all__ = ["plot_forest", "plot_meta_forest"]