"""
data/surveillance.py - Epidemiological surveillance data utilities.
Tools for ingesting, cleaning, aggregating, and alerting on
routine surveillance data designed for the Burkina Faso / francophone
African public health context (SNIS, DHIS2-compatible CSV exports).
Public classes
--------------
SurveillanceDataset structured weekly/daily case counts per site/disease
AlertEngine threshold-based and statistical alert detection
Public functions
----------------
from_dhis2_csv() load DHIS2 export CSV
from_weekly_bulletin() parse standard weekly bulletin table
aggregate_by() temporal or spatial aggregation
compute_attack_rate() attack rate per stratum
endemic_channel() historical percentile envelope (alert zones)
"""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
# SurveillanceDataset
[docs]
class SurveillanceDataset:
"""
Structured surveillance case count dataset.
Wraps a pandas DataFrame with columns:
date / week / period time axis
district / site spatial unit (optional)
disease disease or syndrome name
cases integer case count
deaths integer death count (optional)
population population at risk (optional)
Built from CSV, DHIS2 exports, or a plain DataFrame.
Example::
from episia.data.surveillance import SurveillanceDataset
ds = SurveillanceDataset.from_csv("meningite_2024.csv",
date_col="semaine",
cases_col="cas")
print(ds.summary())
ds.epicurve().plot().show()
alerts = ds.alert_engine().run()
"""
[docs]
def __init__(self, df, *, date_col: str = "date",
cases_col: str = "cases", deaths_col: Optional[str] = None,
district_col: Optional[str] = None,
disease_col: Optional[str] = None,
population_col: Optional[str] = None):
try:
import pandas as pd
except ImportError:
raise ImportError("pandas is required. pip install pandas")
self._df = df.copy()
self.date_col = date_col
self.cases_col = cases_col
self.deaths_col = deaths_col
self.district_col = district_col
self.disease_col = disease_col
self.population_col = population_col
# Ensure date column is datetime
if date_col in self._df.columns:
self._df[date_col] = pd.to_datetime(self._df[date_col],
errors="coerce")
# Constructors
[docs]
@classmethod
def from_csv(
cls,
path: Union[str, Path],
date_col: str = "date",
cases_col: str = "cases",
deaths_col: Optional[str] = None,
district_col: Optional[str] = None,
disease_col: Optional[str] = None,
population_col: Optional[str] = None,
**read_kwargs,
) -> "SurveillanceDataset":
"""
Load from CSV file.
Args:
path: Path to CSV file.
date_col: Column name for date / week.
cases_col: Column name for case counts.
deaths_col: Column name for deaths (optional).
district_col: Column name for district / site (optional).
disease_col: Column name for disease / syndrome (optional).
population_col: Column for population at risk (optional).
**read_kwargs: Passed to pd.read_csv.
Returns:
SurveillanceDataset.
"""
import pandas as pd
df = pd.read_csv(path, **read_kwargs)
return cls(df, date_col=date_col, cases_col=cases_col,
deaths_col=deaths_col, district_col=district_col,
disease_col=disease_col, population_col=population_col)
[docs]
@classmethod
def from_dict(
cls,
data: Dict[str, List],
**kwargs,
) -> "SurveillanceDataset":
"""Create from a plain dict of lists."""
import pandas as pd
return cls(pd.DataFrame(data), **kwargs)
[docs]
@classmethod
def from_dataframe(
cls,
df,
**kwargs,
) -> "SurveillanceDataset":
"""Wrap an existing DataFrame."""
return cls(df, **kwargs)
# Properties
@property
def df(self):
return self._df
@property
def n_records(self) -> int:
return len(self._df)
@property
def total_cases(self) -> int:
return int(self._df[self.cases_col].sum())
@property
def total_deaths(self) -> Optional[int]:
if self.deaths_col and self.deaths_col in self._df.columns:
return int(self._df[self.deaths_col].sum())
return None
@property
def cfr(self) -> Optional[float]:
"""Case fatality rate = total_deaths / total_cases."""
d = self.total_deaths
c = self.total_cases
if d is not None and c > 0:
return d / c
return None
@property
def date_range(self) -> Tuple[Any, Any]:
col = self._df[self.date_col]
return col.min(), col.max()
@property
def districts(self) -> List[str]:
if self.district_col and self.district_col in self._df.columns:
return sorted(self._df[self.district_col].dropna().unique().tolist())
return []
@property
def diseases(self) -> List[str]:
if self.disease_col and self.disease_col in self._df.columns:
return sorted(self._df[self.disease_col].dropna().unique().tolist())
return []
# Filtering
[docs]
def filter_district(self, district: str) -> "SurveillanceDataset":
"""Return a new dataset filtered to a single district."""
if not self.district_col:
raise ValueError("No district_col defined.")
mask = self._df[self.district_col] == district
return SurveillanceDataset(
self._df[mask], date_col=self.date_col,
cases_col=self.cases_col, deaths_col=self.deaths_col,
district_col=self.district_col, disease_col=self.disease_col,
population_col=self.population_col,
)
[docs]
def filter_disease(self, disease: str) -> "SurveillanceDataset":
"""Return a new dataset filtered to a single disease."""
if not self.disease_col:
raise ValueError("No disease_col defined.")
mask = self._df[self.disease_col] == disease
return SurveillanceDataset(
self._df[mask], date_col=self.date_col,
cases_col=self.cases_col, deaths_col=self.deaths_col,
district_col=self.district_col, disease_col=self.disease_col,
population_col=self.population_col,
)
[docs]
def filter_date(self, start: Any = None,
end: Any = None) -> "SurveillanceDataset":
"""Filter to a date range (inclusive)."""
import pandas as pd
df = self._df
if start is not None:
df = df[df[self.date_col] >= pd.to_datetime(start)]
if end is not None:
df = df[df[self.date_col] <= pd.to_datetime(end)]
return SurveillanceDataset(
df, date_col=self.date_col, cases_col=self.cases_col,
deaths_col=self.deaths_col, district_col=self.district_col,
disease_col=self.disease_col,
population_col=self.population_col,
)
# Aggregation
[docs]
def aggregate(
self,
freq: str = "W",
group_by: Optional[List[str]] = None,
):
"""
Aggregate cases by time frequency and optional grouping columns.
Args:
freq: Pandas offset alias ('D'=daily, 'W'=weekly, 'ME'=monthly).
group_by: Additional columns to group by (district, disease…).
Returns:
pandas DataFrame with aggregated counts.
"""
import pandas as pd
df = self._df.copy()
# Normalize freq alias for to_period compatibility
_freq_map = {'ME': 'M', 'QE': 'Q', 'YE': 'Y', 'h': 'H'}
_p_freq = _freq_map.get(freq, freq)
df["_period"] = df[self.date_col].dt.to_period(_p_freq).dt.start_time
agg_cols = {"_period": "first", self.cases_col: "sum"}
if self.deaths_col and self.deaths_col in df.columns:
agg_cols[self.deaths_col] = "sum"
if self.population_col and self.population_col in df.columns:
agg_cols[self.population_col] = "first"
keys = ["_period"]
if group_by:
keys += [c for c in group_by if c in df.columns]
result = (df.groupby(keys, as_index=False)
.agg({c: agg_cols[c] for c in agg_cols if c != "_period"}))
result = result.rename(columns={"_period": "period"})
return result.sort_values("period").reset_index(drop=True)
# Epidemiological metrics
[docs]
def attack_rate(
self,
population: Optional[int] = None,
per: int = 100_000,
) -> float:
"""
Compute overall attack rate.
Args:
population: Population denominator (uses population_col if None).
per: Rate denominator (default 100,000).
Returns:
Attack rate per `per` population.
"""
if population is None:
if self.population_col and self.population_col in self._df.columns:
population = int(self._df[self.population_col].iloc[0])
else:
raise ValueError(
"population argument required when population_col is not set."
)
return self.total_cases / population * per
[docs]
def weekly_attack_rates(self, population: int, per: int = 100_000):
"""
Compute weekly attack rates.
Args:
population: Population at risk.
per: Rate denominator.
Returns:
pandas DataFrame with columns: period, cases, attack_rate.
"""
agg = self.aggregate(freq="W")
agg["attack_rate"] = agg[self.cases_col] / population * per
return agg
[docs]
def endemic_channel(
self,
historical_years: Optional[List[int]] = None,
percentiles: Tuple[float, float, float] = (25, 50, 75),
) -> Dict[str, Any]:
"""
Compute the endemic channel (historical percentile envelope).
Groups by ISO week number across historical years.
Returns the percentile bands used for alert zone classification.
Args:
historical_years: Years to include (all years if None).
percentiles: (low, median, high) percentiles.
Returns:
Dict with keys: 'weeks', 'p_low', 'p_mid', 'p_high'.
"""
import pandas as pd
df = self._df.copy()
df["_year"] = df[self.date_col].dt.year
df["_week"] = df[self.date_col].dt.isocalendar().week.astype(int)
if historical_years:
df = df[df["_year"].isin(historical_years)]
grouped = df.groupby("_week")[self.cases_col]
p_low = grouped.quantile(percentiles[0] / 100)
p_mid = grouped.quantile(percentiles[1] / 100)
p_high = grouped.quantile(percentiles[2] / 100)
return {
"weeks": p_low.index.tolist(),
"p_low": p_low.values,
"p_mid": p_mid.values,
"p_high": p_high.values,
"percentiles": percentiles,
}
# Export to Episia viz
[docs]
def to_timeseries_result(self):
"""
Convert to api.results.TimeSeriesResult for viz integration.
Returns:
TimeSeriesResult ready for plot_epicurve().
"""
from ..api.results import TimeSeriesResult
agg = self.aggregate(freq="W")
times = agg["period"].dt.strftime("%Y-W%W").values
values = agg[self.cases_col].values.astype(float)
return TimeSeriesResult(times=times, values=values)
# Summary
[docs]
def summary(self) -> Dict[str, Any]:
"""Return a summary statistics dict."""
start, end = self.date_range
s: Dict[str, Any] = {
"n_records": self.n_records,
"total_cases": self.total_cases,
"date_start": str(start),
"date_end": str(end),
}
if self.total_deaths is not None:
s["total_deaths"] = self.total_deaths
s["cfr"] = self.cfr
if self.districts:
s["n_districts"] = len(self.districts)
s["districts"] = self.districts[:10]
if self.diseases:
s["diseases"] = self.diseases
return s
def __repr__(self) -> str:
start, end = self.date_range
return (
f"SurveillanceDataset("
f"n={self.n_records}, "
f"cases={self.total_cases:,}, "
f"{start} → {end})"
)
# AlertEngine
[docs]
@dataclass
class Alert:
"""A single surveillance alert."""
period: Any
value: float
threshold: float
kind: str # 'threshold', 'zscore', 'endemic_channel'
severity: str # 'warning', 'alert', 'epidemic'
district: Optional[str] = None
disease: Optional[str] = None
message: str = ""
[docs]
class AlertEngine:
"""
Threshold-based and statistical alert detection for surveillance data.
Example::
engine = AlertEngine(dataset)
alerts = engine.run(
threshold=10,
zscore_threshold=2.0,
use_endemic_channel=True,
)
for a in alerts:
print(a.period, a.severity, a.message)
"""
[docs]
def __init__(self, dataset: SurveillanceDataset):
self.dataset = dataset
[docs]
def run(
self,
threshold: Optional[float] = None,
zscore_threshold: float = 2.0,
use_endemic_channel: bool = False,
historical_years: Optional[List[int]] = None,
freq: str = "W",
) -> List[Alert]:
"""
Run all enabled alert detectors.
Args:
threshold: Absolute case count threshold.
zscore_threshold: Z-score threshold for statistical alert.
use_endemic_channel: Use endemic channel (requires ≥3 historical years).
historical_years: Years to use for endemic channel baseline.
freq: Aggregation frequency ('D', 'W', 'ME').
Returns:
List of Alert objects, sorted by period.
"""
alerts: List[Alert] = []
agg = self.dataset.aggregate(freq=freq)
values = agg[self.dataset.cases_col].values.astype(float)
periods = agg["period"].values
# Absolute threshold
if threshold is not None:
for period, val in zip(periods, values):
if val >= threshold:
severity = "epidemic" if val >= threshold * 2 else "alert"
alerts.append(Alert(
period=period, value=float(val),
threshold=float(threshold), kind="threshold",
severity=severity,
message=(
f"{val:.0f} cas ≥ seuil {threshold:.0f}"
),
))
# Z-score
if len(values) >= 4:
mean = np.mean(values)
std = np.std(values)
if std > 0:
zscores = (values - mean) / std
for period, val, z in zip(periods, values, zscores):
if z >= zscore_threshold:
severity = "epidemic" if z >= zscore_threshold * 1.5 else "warning"
alerts.append(Alert(
period=period, value=float(val),
threshold=float(mean + zscore_threshold * std),
kind="zscore", severity=severity,
message=f"Z-score={z:.2f} ≥ {zscore_threshold}",
))
# Endemic channel
if use_endemic_channel:
try:
channel = self.dataset.endemic_channel(historical_years)
week_map = dict(zip(channel["weeks"], channel["p_high"]))
import pandas as pd
for period, val in zip(periods, values):
p = pd.Timestamp(period)
week = p.isocalendar()[1]
if week in week_map and val > week_map[week]:
alerts.append(Alert(
period=period, value=float(val),
threshold=float(week_map[week]),
kind="endemic_channel", severity="alert",
message=(
f"Semaine {week}: {val:.0f} cas "
f"> P75 historique {week_map[week]:.0f}"
),
))
except Exception:
pass
# Sort by period
try:
alerts.sort(key=lambda a: str(a.period))
except Exception:
pass
return alerts
[docs]
def alert_summary(self, alerts: List[Alert]) -> Dict[str, Any]:
"""Summarise a list of alerts."""
if not alerts:
return {"n_alerts": 0, "severity_counts": {}}
from collections import Counter
sev = Counter(a.severity for a in alerts)
return {
"n_alerts": len(alerts),
"severity_counts": dict(sev),
"first_alert": str(alerts[0].period),
"last_alert": str(alerts[-1].period),
}
# Module-level convenience functions
[docs]
def from_dhis2_csv(
path: Union[str, Path],
date_col: str = "periodName",
cases_col: str = "value",
district_col: str = "orgUnitName",
**kwargs,
) -> SurveillanceDataset:
"""
Load a DHIS2 standard CSV export.
DHIS2 exports typically have columns:
periodName, orgUnitName, dataElementName, value, …
Args:
path: Path to DHIS2 CSV export.
date_col: Column with period label.
cases_col: Column with case count value.
district_col: Column with organisation unit name.
**kwargs: Passed to pd.read_csv.
Returns:
SurveillanceDataset.
"""
return SurveillanceDataset.from_csv(
path, date_col=date_col, cases_col=cases_col,
district_col=district_col, **kwargs,
)
[docs]
def compute_attack_rate(
cases: int,
population: int,
per: int = 100_000,
) -> float:
"""
Compute attack rate.
Args:
cases: Number of cases.
population: Population at risk.
per: Rate denominator (default 100,000).
Returns:
Attack rate per `per` population.
"""
if population <= 0:
raise ValueError(f"population must be > 0, got {population}.")
return cases / population * per
[docs]
def endemic_channel(
dataset: SurveillanceDataset,
historical_years: Optional[List[int]] = None,
percentiles: Tuple[float, float, float] = (25, 50, 75),
) -> Dict[str, Any]:
"""Module-level alias for dataset.endemic_channel()."""
return dataset.endemic_channel(historical_years, percentiles)
[docs]
def aggregate_by(
dataset: SurveillanceDataset,
freq: str = "W",
group_by: Optional[List[str]] = None,
):
"""Module-level alias for dataset.aggregate()."""
return dataset.aggregate(freq=freq, group_by=group_by)
__all__ = [
"SurveillanceDataset",
"AlertEngine",
"Alert",
"from_dhis2_csv",
"compute_attack_rate",
"endemic_channel",
"aggregate_by",
]