"""
viz/roc.py - ROC curve visualizations for Episia.
Public functions
----------------
plot_roc               — single ROC curve with AUC + optimal threshold
plot_roc_compare       — multiple ROC curves on the same axes (model comparison)
plot_precision_recall  — precision-recall curve (complement to ROC)
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import numpy as np
from .plotters import get_plotter, PlotConfig, AnimationConfig
from .themes.registry import get_palette
# plot_roc
[docs]
def plot_roc(
    result: Any,
    *,
    title: str = "ROC Curve",
    animate: bool = False,
    backend: str = "plotly",
    theme: str = "scientific",
    config: Optional[PlotConfig] = None,
) -> Any:
    """
    Plot a single ROC curve with AUC annotation and optimal threshold marker.

    Rendering is delegated to ``get_plotter(backend).plot_roc``.

    Args:
        result:  ROCResult from stats.diagnostic.roc_analysis().
        title:   Figure title (ignored when ``config`` is given).
        animate: Trace the curve from (0,0) to (1,1) (Plotly only; ignored
                 when ``config`` is given).
        backend: 'plotly' (default) or 'matplotlib'.
        theme:   Theme name (ignored when ``config`` is given).
        config:  Full PlotConfig override; when supplied it is passed to the
                 plotter unchanged.

    Returns:
        Figure object produced by the selected backend.

    Example::

        from episia.stats.diagnostic import roc_analysis
        from episia.viz.roc import plot_roc
        result = roc_analysis(y_true, y_score)
        plot_roc(result, title="Malaria RDT ROC").show()
    """
    if config is None:
        if animate:
            # Import relative to this package instead of through a hard-coded
            # absolute module path, so it resolves under whatever name the
            # package is installed as (the examples above use ``episia``).
            from .plotters.base_plotter import AnimationType

            anim = AnimationConfig(
                enabled=True,
                anim_type=AnimationType.CONTINUOUS,
                duration_ms=3000,
                frame_ms=40,
            )
        else:
            anim = AnimationConfig.default()
        config = PlotConfig(title=title, theme=theme, animation=anim)
    return get_plotter(backend).plot_roc(result, config=config)
# plot_roc_compare
[docs]
def plot_roc_compare(
    results: List[Any],
    labels: Optional[List[str]] = None,
    *,
    title: str = "ROC Curve Comparison",
    backend: str = "plotly",
    theme: str = "scientific",
    config: Optional[PlotConfig] = None,
) -> Any:
    """
    Overlay multiple ROC curves for model comparison.

    Each curve is drawn in the theme palette colour for its position, with
    its AUC appended to the legend entry. A dotted/dashed diagonal marks the
    chance line.

    Args:
        results: List of ROCResult objects (each must expose ``fpr``, ``tpr``
                 and ``auc``).
        labels:  Display name for each curve, one per result. Defaults to
                 'Model 1', 'Model 2', ... when omitted or empty.
        title:   Figure title (a ``config`` with its own title takes
                 precedence).
        backend: 'plotly' or 'matplotlib' (any other value falls back to
                 matplotlib).
        theme:   Theme name (selects the palette and, for matplotlib, the
                 global style).
        config:  Full PlotConfig override.

    Returns:
        Figure object (plotly ``go.Figure`` or matplotlib ``Figure``).

    Raises:
        ValueError: if ``results`` is empty, or if ``labels`` is given but
            its length differs from ``len(results)``.

    Example::

        fig = plot_roc_compare(
            [result_lr, result_rf, result_xgb],
            labels=["Logistic", "Random Forest", "XGBoost"],
        )
    """
    if not results:
        raise ValueError("results list is empty.")
    if not labels:
        labels = [f"Model {i + 1}" for i in range(len(results))]
    elif len(labels) != len(results):
        # zip() would otherwise silently drop the curves without a label.
        raise ValueError(
            f"got {len(labels)} labels for {len(results)} results; "
            "provide exactly one label per ROC result."
        )

    pal = get_palette(theme)
    if config is None:
        config = PlotConfig(title=title, theme=theme)

    if backend == "plotly":
        import plotly.graph_objects as go
        from .plotters.plotly_plotter import _layout

        fig = go.Figure()
        # Reference diagonal (chance line).
        fig.add_trace(go.Scatter(
            x=[0, 1], y=[0, 1],
            mode="lines",
            line=dict(color="#aaaaaa", width=1, dash="dot"),
            showlegend=False, hoverinfo="skip",
        ))
        for i, (res, label) in enumerate(zip(results, labels)):
            color = pal[i % len(pal)]
            fig.add_trace(go.Scatter(
                x=list(res.fpr), y=list(res.tpr),
                mode="lines",
                name=f"{label} (AUC={res.auc:.3f})",
                line=dict(color=color, width=2.2),
            ))
        fig.update_layout(_layout(
            config,
            xaxis=dict(title="False Positive Rate", range=[0, 1],
                       showgrid=config.show_grid),
            yaxis=dict(title="True Positive Rate", range=[0, 1.02],
                       showgrid=config.show_grid),
        ))
        return fig

    import matplotlib.pyplot as plt
    from .themes.registry import apply_mpl_theme

    apply_mpl_theme(theme)
    # Square figure so the equal-aspect ROC axes fill the canvas.
    size = min(config.width, config.height) / 100
    fig, ax = plt.subplots(figsize=(size, size), facecolor="white")
    ax.plot([0, 1], [0, 1], color="#aaaaaa", linewidth=1,
            linestyle="--", label="Random")
    for i, (res, label) in enumerate(zip(results, labels)):
        color = pal[i % len(pal)]
        ax.fill_between(res.fpr, res.tpr, alpha=0.05, color=color)
        ax.plot(res.fpr, res.tpr, color=color, linewidth=2.2,
                label=f"{label} (AUC={res.auc:.3f})")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.02)
    ax.set_aspect("equal")
    ax.set_xlabel("False Positive Rate", fontsize=config.font_size)
    ax.set_ylabel("True Positive Rate", fontsize=config.font_size)
    # Honour the config override, as the plotly branch does via _layout.
    ax.set_title(config.title or title, fontsize=config.font_size + 2,
                 fontweight="bold")
    ax.legend(fontsize=config.font_size - 1, loc="lower right")
    fig.tight_layout()
    return fig
# plot_precision_recall
[docs]
def plot_precision_recall(
    y_true: Any,
    y_score: Any,
    *,
    label: str = "Model",
    title: str = "Precision-Recall Curve",
    backend: str = "plotly",
    theme: str = "scientific",
    config: Optional[PlotConfig] = None,
) -> Any:
    """
    Plot a precision-recall curve.

    More informative than ROC when classes are imbalanced. The curve and the
    average precision (AP) come from scikit-learn. A horizontal line marks
    the no-skill baseline, which is the fraction of positive labels.

    Args:
        y_true:  True binary labels; label 1 is the positive class
                 (scikit-learn's default ``pos_label``).
        y_score: Predicted probabilities or scores.
        label:   Curve label.
        title:   Figure title (a ``config`` with its own title takes
                 precedence).
        backend: 'plotly' or 'matplotlib' (any other value falls back to
                 matplotlib).
        theme:   Theme name.
        config:  Full PlotConfig override.

    Returns:
        Figure object (plotly ``go.Figure`` or matplotlib ``Figure``).
    """
    from sklearn.metrics import average_precision_score, precision_recall_curve

    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    ap = average_precision_score(y_true, y_score)
    # No-skill baseline = prevalence of the positive class. Compare against 1
    # (sklearn's default pos_label) rather than using y_true.mean(), which is
    # wrong for {-1, 1}-encoded labels that sklearn otherwise accepts.
    baseline = float(np.mean(y_true == 1))

    if config is None:
        config = PlotConfig(title=title, theme=theme)
    pal = get_palette(theme)
    color = pal[0]

    if backend == "plotly":
        import plotly.graph_objects as go
        from .plotters.plotly_plotter import _layout

        fig = go.Figure()
        # Baseline (no-skill classifier).
        fig.add_trace(go.Scatter(
            x=[0, 1], y=[baseline, baseline],
            mode="lines",
            line=dict(color="#aaaaaa", width=1, dash="dot"),
            showlegend=False, hoverinfo="skip",
        ))
        fig.add_trace(go.Scatter(
            x=list(recall), y=list(precision),
            mode="lines",
            fill="tozeroy",
            # NOTE(review): assumes palette entries are hex strings — confirm.
            fillcolor=f"rgba({_hex_to_rgb(color)},0.08)",
            line=dict(color=color, width=2.5),
            name=f"{label} (AP={ap:.3f})",
        ))
        fig.update_layout(_layout(
            config,
            xaxis=dict(title="Recall", range=[0, 1], showgrid=config.show_grid),
            yaxis=dict(title="Precision", range=[0, 1.02],
                       showgrid=config.show_grid),
            annotations=[dict(
                x=0.97, y=0.05, xref="paper", yref="paper",
                text=f"AP = {ap:.3f}",
                showarrow=False,
                font=dict(size=config.font_size + 1),
                bgcolor="rgba(255,255,255,0.8)",
                bordercolor="#cccccc", borderwidth=1, align="right",
            )],
        ))
        return fig

    import matplotlib.pyplot as plt
    from .themes.registry import apply_mpl_theme

    apply_mpl_theme(theme)
    fig, ax = plt.subplots(
        figsize=(config.width / 100, config.height / 100),
        facecolor="white",
    )
    ax.axhline(baseline, color="#aaaaaa", linewidth=1, linestyle="--",
               label="Baseline")
    ax.fill_between(recall, precision, alpha=0.08, color=color)
    ax.plot(recall, precision, color=color, linewidth=2.5,
            label=f"{label} (AP={ap:.3f})")
    ax.text(0.97, 0.05, f"AP = {ap:.3f}",
            transform=ax.transAxes, ha="right", va="bottom",
            fontsize=config.font_size - 1,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white",
                      edgecolor="#cccccc", alpha=0.8))
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.05)
    ax.set_xlabel("Recall", fontsize=config.font_size)
    ax.set_ylabel("Precision", fontsize=config.font_size)
    # Honour the config override, as the plotly branch does via _layout.
    ax.set_title(config.title or title, fontsize=config.font_size + 2,
                 fontweight="bold")
    ax.legend(fontsize=config.font_size - 1)
    fig.tight_layout()
    return fig
def _hex_to_rgb(h: str) -> str:
h = h.lstrip("#")
return f"{int(h[0:2],16)},{int(h[2:4],16)},{int(h[4:6],16)}"
# Public API exported by ``from ... import *``; the helper _hex_to_rgb is not exported.
__all__ = ["plot_roc", "plot_roc_compare", "plot_precision_recall"]