"""
viz/curves.py - Epidemic curve and trend visualizations for Episia.
Public functions
----------------
plot_epicurve bar chart of cases over time, optional trend overlay
plot_trend trend line only (linear / LOESS / moving average)
plot_incidence incidence rate over time with optional CI band
plot_doubling log-scale growth with doubling time annotation
All functions accept both the internal EpidemicCurve / TimeSeriesResult
objects from stats.time_series and the unified TimeSeriesResult from
api.results. Raw numpy arrays are also accepted for convenience.
Backend selection
-----------------
backend="plotly" (default) interactive, web-ready
backend="matplotlib" publication quality, static
"""
from __future__ import annotations
from typing import Any, List, Optional, Union
import numpy as np
from .plotters import get_plotter, PlotConfig, AnimationConfig, AnimationType
from .plotters.base_plotter import OutputFormat
# Internal helpers
def _coerce_times_values(
source: Any,
times: Optional[np.ndarray],
values: Optional[np.ndarray],
) -> tuple:
"""
Extract (times, values, trend, trend_method) from various input types.
Accepts:
- api.results.TimeSeriesResult
- stats.time_series.EpidemicCurve
- stats.time_series.TimeSeriesResult
- raw arrays (times + values required)
"""
trend = None
trend_method = None
doubling_time = None
if source is not None:
# api.results.TimeSeriesResult
if hasattr(source, "times") and hasattr(source, "values"):
times = source.times
values = source.values
trend = getattr(source, "trend", None)
trend_method = getattr(source, "trend_method", None)
doubling_time = getattr(source, "doubling_time", None)
# stats.time_series.EpidemicCurve
elif hasattr(source, "dates") and hasattr(source, "counts"):
times = source.dates
values = source.counts
# stats.time_series.TimeSeriesResult
elif hasattr(source, "dates") and hasattr(source, "observed"):
times = source.dates
values = source.observed
trend = getattr(source, "trend", None)
trend_method = getattr(source, "method", None)
if times is None or values is None:
raise ValueError(
"Provide either a result object or explicit times and values arrays."
)
times = np.asarray(times)
values = np.asarray(values, dtype=float)
return times, values, trend, trend_method, doubling_time
# plot_epicurve
[docs]
def plot_epicurve(
result: Any = None,
*,
times: Optional[np.ndarray] = None,
values: Optional[np.ndarray] = None,
title: str = "Epidemic Curve",
xlabel: str = "Period",
ylabel: str = "Cases",
backend: str = "plotly",
theme: str = "scientific",
animate: bool = False,
config: Optional[PlotConfig] = None,
) -> Any:
"""
Plot an epidemic curve (cases over time) as a bar chart.
Args:
result: TimeSeriesResult, EpidemicCurve, or None (use times/values).
times: Array of time labels (used if result is None).
values: Array of case counts (used if result is None).
title: Figure title.
xlabel: X-axis label.
ylabel: Y-axis label.
backend: 'plotly' (default) or 'matplotlib'.
theme: Theme name.
animate: If True, bars build up frame by frame (Plotly only).
config: Full PlotConfig override supersedes individual args.
Returns:
plotly.graph_objects.Figure or matplotlib.figure.Figure
Examples::
# From a TimeSeriesResult
fig = plot_epicurve(result, title="Ebola 2014 Guinea")
fig.show()
# From raw arrays
fig = plot_epicurve(times=weeks, values=counts, animate=True)
# Publication export
fig = plot_epicurve(result, backend="matplotlib")
fig.savefig("figure1.pdf", dpi=300, bbox_inches="tight")
"""
t, v, trend, trend_method, _ = _coerce_times_values(result, times, values)
# Build a lightweight proxy that plotters can consume
class _Proxy:
pass
proxy = _Proxy()
proxy.times = t
proxy.values = v
proxy.trend = trend
proxy.trend_method = trend_method
if config is None:
if animate:
anim_cfg = AnimationConfig.frame_buildup(len(t))
# Downsample data to match capped frame count
step = max(1, len(t) // AnimationConfig.MAX_ANIMATION_FRAMES)
if step > 1:
t = t[::step]
v = v[::step]
if trend is not None:
trend = trend[::step]
else:
anim_cfg = AnimationConfig.default()
config = PlotConfig(
title=title,
xlabel=xlabel,
ylabel=ylabel,
theme=theme,
animation=anim_cfg,
)
plotter = get_plotter(backend)
return plotter.plot_epicurve(proxy, config=config)
# plot_trend
[docs]
def plot_trend(
result: Any = None,
*,
times: Optional[np.ndarray] = None,
values: Optional[np.ndarray] = None,
show_observed: bool = True,
title: str = "Trend Analysis",
xlabel: str = "Period",
ylabel: str = "Value",
backend: str = "plotly",
theme: str = "scientific",
config: Optional[PlotConfig] = None,
) -> Any:
"""
Plot a trend line with optional observed values overlay.
Args:
result: TimeSeriesResult with a trend array, or raw input.
show_observed: If True, plot observed values as scatter points.
title: Figure title.
xlabel / ylabel: Axis labels.
backend: 'plotly' or 'matplotlib'.
theme: Theme name.
config: Full PlotConfig override.
Returns:
Figure object.
Examples::
fig = plot_trend(result, show_observed=True, title="Weekly incidence trend")
"""
t, v, trend, trend_method, _ = _coerce_times_values(result, times, values)
if config is None:
config = PlotConfig(
title=title, xlabel=xlabel, ylabel=ylabel, theme=theme,
)
pal = _get_palette(theme)
if backend == "plotly":
import plotly.graph_objects as go
fig = go.Figure()
if show_observed:
fig.add_trace(go.Scatter(
x=list(t), y=list(v),
mode="markers",
marker=dict(color=pal[0], size=6, opacity=0.6),
name="Observed",
))
if trend is not None:
fig.add_trace(go.Scatter(
x=list(t), y=list(trend),
mode="lines",
line=dict(color=pal[1], width=2.5),
name=trend_method or "Trend",
))
else:
# Fallback: linear regression
_t_num = np.arange(len(t), dtype=float)
slope, intercept = np.polyfit(_t_num, v, 1)
fitted = slope * _t_num + intercept
fig.add_trace(go.Scatter(
x=list(t), y=list(fitted),
mode="lines",
line=dict(color=pal[1], width=2.5, dash="dash"),
name="Linear trend",
))
from .plotters.plotly_plotter import _layout
fig.update_layout(_layout(config))
return fig
else:
import matplotlib.pyplot as plt
from .themes.registry import apply_mpl_theme
apply_mpl_theme(theme)
fig, ax = plt.subplots(figsize=(config.width / 100, config.height / 100),
facecolor="white")
if show_observed:
ax.scatter(range(len(t)), v, color=pal[0], s=30, alpha=0.6,
label="Observed", zorder=3)
trend_y = trend if trend is not None else np.polyval(
np.polyfit(np.arange(len(t)), v, 1), np.arange(len(t))
)
ax.plot(range(len(t)), trend_y, color=pal[1], linewidth=2.5,
label=trend_method or "Trend")
ax.set_xticks(range(len(t)))
ax.set_xticklabels([str(x) for x in t], rotation=45, ha="right",
fontsize=config.font_size - 2)
ax.set_xlabel(xlabel, fontsize=config.font_size)
ax.set_ylabel(ylabel, fontsize=config.font_size)
ax.set_title(title, fontsize=config.font_size + 2, fontweight="bold")
if config.show_legend:
ax.legend(fontsize=config.font_size - 1)
fig.tight_layout()
return fig
# plot_incidence
[docs]
def plot_incidence(
result: Any = None,
*,
times: Optional[np.ndarray] = None,
rates: Optional[np.ndarray] = None,
ci_lower: Optional[np.ndarray] = None,
ci_upper: Optional[np.ndarray] = None,
per: int = 100_000,
title: str = "Incidence Rate",
xlabel: str = "Period",
ylabel: Optional[str] = None,
backend: str = "plotly",
theme: str = "scientific",
config: Optional[PlotConfig] = None,
) -> Any:
"""
Plot incidence rate over time with optional confidence interval band.
Args:
result: TimeSeriesResult or EpidemicCurve (rates taken from values).
times: Time labels array.
rates: Incidence rate array.
ci_lower: Lower CI bound array (optional).
ci_upper: Upper CI bound array (optional).
per: Population denominator for ylabel label (default 100 000).
title: Figure title.
xlabel: X-axis label.
ylabel: Y-axis label (auto-generated if None).
backend: 'plotly' or 'matplotlib'.
theme: Theme name.
config: Full PlotConfig override.
Returns:
Figure object.
"""
t, v, _, _, _ = _coerce_times_values(result, times, rates)
y_label = ylabel or f"Rate per {per:,}"
if config is None:
config = PlotConfig(
title=title, xlabel=xlabel, ylabel=y_label, theme=theme,
)
pal = _get_palette(theme)
if backend == "plotly":
import plotly.graph_objects as go
fig = go.Figure()
# CI band
if ci_lower is not None and ci_upper is not None:
_lo = np.asarray(ci_lower)
_hi = np.asarray(ci_upper)
fig.add_trace(go.Scatter(
x=list(t) + list(t[::-1]),
y=list(_hi) + list(_lo[::-1]),
fill="toself",
fillcolor=f"rgba({_hex_to_rgb(pal[0])},0.15)",
line=dict(color="rgba(0,0,0,0)"),
showlegend=False,
hoverinfo="skip",
))
fig.add_trace(go.Scatter(
x=list(t), y=list(v),
mode="lines+markers",
line=dict(color=pal[0], width=2.5),
marker=dict(size=5),
name=y_label,
))
from .plotters.plotly_plotter import _layout
fig.update_layout(_layout(config))
return fig
else:
import matplotlib.pyplot as plt
from .themes.registry import apply_mpl_theme
apply_mpl_theme(theme)
fig, ax = plt.subplots(figsize=(config.width / 100, config.height / 100),
facecolor="white")
x = np.arange(len(t))
if ci_lower is not None and ci_upper is not None:
ax.fill_between(x, ci_lower, ci_upper, alpha=0.15,
color=pal[0], label="95% CI")
ax.plot(x, v, color=pal[0], linewidth=2.5,
marker="o", markersize=4, label=y_label)
ax.set_xticks(x)
ax.set_xticklabels([str(s) for s in t], rotation=45, ha="right",
fontsize=config.font_size - 2)
ax.set_xlabel(xlabel, fontsize=config.font_size)
ax.set_ylabel(y_label, fontsize=config.font_size)
ax.set_title(title, fontsize=config.font_size + 2, fontweight="bold")
if config.show_legend:
ax.legend(fontsize=config.font_size - 1)
fig.tight_layout()
return fig
# plot_doubling
[docs]
def plot_doubling(
result: Any = None,
*,
times: Optional[np.ndarray] = None,
values: Optional[np.ndarray] = None,
doubling_time: Optional[float] = None,
title: str = "Growth Curve",
xlabel: str = "Period",
ylabel: str = "Cases (log scale)",
backend: str = "plotly",
theme: str = "scientific",
config: Optional[PlotConfig] = None,
) -> Any:
"""
Plot cumulative cases on a log scale with doubling time annotation.
Args:
result: TimeSeriesResult (doubling_time read from object if present).
times: Time labels.
values: Cumulative case counts.
doubling_time: Doubling time in periods (overrides result attribute).
title / xlabel / ylabel: Labels.
backend: 'plotly' or 'matplotlib'.
theme: Theme name.
config: Full PlotConfig override.
Returns:
Figure object.
"""
t, v, _, _, dt_from_result = _coerce_times_values(result, times, values)
dt = doubling_time or dt_from_result
if config is None:
config = PlotConfig(
title=title, xlabel=xlabel, ylabel=ylabel, theme=theme,
)
pal = _get_palette(theme)
# Safe log transform
v_log = np.where(v > 0, np.log2(v), np.nan)
if backend == "plotly":
import plotly.graph_objects as go
fig = go.Figure()
fig.add_trace(go.Scatter(
x=list(t), y=list(v),
mode="lines+markers",
line=dict(color=pal[0], width=2.5),
marker=dict(size=5),
name="Cases",
))
annotations = []
if dt is not None:
annotations.append(dict(
x=0.02, y=0.95, xref="paper", yref="paper",
text=f"Doubling time: {dt:.1f} periods",
showarrow=False,
font=dict(size=config.font_size, color="#333333"),
bgcolor="rgba(255,255,255,0.8)",
bordercolor="#cccccc",
borderwidth=1,
))
from .plotters.plotly_plotter import _layout
fig.update_layout(_layout(
config,
yaxis=dict(
type="log",
title=ylabel,
showgrid=config.show_grid,
),
annotations=annotations,
))
return fig
else:
import matplotlib.pyplot as plt
from .themes.registry import apply_mpl_theme
apply_mpl_theme(theme)
fig, ax = plt.subplots(figsize=(config.width / 100, config.height / 100),
facecolor="white")
valid = v > 0
ax.semilogy(np.arange(len(t))[valid], v[valid],
color=pal[0], linewidth=2.5,
marker="o", markersize=4, label="Cases")
if dt is not None:
ax.text(0.02, 0.95,
f"Doubling time: {dt:.1f} periods",
transform=ax.transAxes, fontsize=config.font_size - 1,
verticalalignment="top",
bbox=dict(boxstyle="round,pad=0.3", facecolor="white",
edgecolor="#cccccc", alpha=0.8))
ax.set_xticks(np.arange(len(t)))
ax.set_xticklabels([str(s) for s in t], rotation=45, ha="right",
fontsize=config.font_size - 2)
ax.set_xlabel(xlabel, fontsize=config.font_size)
ax.set_ylabel(ylabel, fontsize=config.font_size)
ax.set_title(title, fontsize=config.font_size + 2, fontweight="bold")
fig.tight_layout()
return fig
# Internal colour helpers
def _get_palette(theme: str) -> List[str]:
from .themes.registry import get_palette
return get_palette(theme)
def _hex_to_rgb(hex_color: str) -> str:
"""Convert '#rrggbb' to 'r,g,b' string for rgba()."""
h = hex_color.lstrip("#")
r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
return f"{r},{g},{b}"
# Exports
__all__ = [
"plot_epicurve",
"plot_trend",
"plot_incidence",
"plot_doubling",
]