Source code for episia.stats.stratified

"""
This module provides functions for stratified analysis, including
Mantel-Haenszel methods for adjusting for confounding variables
and testing for effect modification.
"""

import numpy as np
from typing import Union, Tuple, Optional, Dict, List
from dataclasses import dataclass
from enum import Enum
import warnings
from scipy import stats

from .contingency import Table2x2


class StratifiedMethod(Enum):
    """Enumerate the stratified-analysis techniques named by this module.

    Each member's value is the lowercase string identifier of the method.
    """

    # Pooled estimates across strata.
    MANTEL_HAENSZEL = "mantel_haenszel"

    # Standardization of rates against a standard population.
    DIRECT_STANDARDIZATION = "direct"

    # Standardization through reference rates (SMR).
    INDIRECT_STANDARDIZATION = "indirect"
@dataclass
class StratifiedTable:
    """Hold one 2x2 table per stratum together with the stratum labels.

    Attributes:
        tables: The per-stratum ``Table2x2`` objects; must be non-empty.
        strata_names: One label per table. When omitted, labels
            ``Stratum_1``, ``Stratum_2``, ... are generated.
    """

    tables: List[Table2x2]
    strata_names: Optional[List[str]] = None

    def __post_init__(self):
        """Check the table list is non-empty and the labels line up with it.

        Raises:
            ValueError: If ``tables`` is empty, or if ``strata_names`` is
                supplied with a length different from ``tables``.
        """
        if not self.tables:
            raise ValueError("At least one table is required")

        n_tables = len(self.tables)
        if self.strata_names is None:
            # Default labels are 1-based.
            self.strata_names = [f"Stratum_{pos}" for pos in range(1, n_tables + 1)]
            return

        if len(self.strata_names) != n_tables:
            raise ValueError("Number of stratum names must match number of tables")

    def __len__(self) -> int:
        """Return the number of strata."""
        return len(self.tables)

    def __getitem__(self, idx):
        """Return the table(s) at ``idx``, delegating to the underlying list."""
        return self.tables[idx]

    def to_dict(self) -> Dict:
        """Return a dictionary with the stratum count, labels and tables.

        Each table is serialized through its own ``to_dict()`` method.
        """
        serialized_tables = [tbl.to_dict() for tbl in self.tables]
        return {
            "n_strata": len(self.tables),
            "strata_names": self.strata_names,
            "tables": serialized_tables,
        }
@dataclass
class MantelHaenszelResult:
    """Bundle the outputs of a Mantel-Haenszel stratified analysis.

    Attributes:
        common_or: Pooled odds ratio.
        common_rr: Pooled risk ratio.
        common_rd: Pooled risk difference.
        or_ci: (lower, upper) interval for the pooled odds ratio.
        rr_ci: (lower, upper) interval for the pooled risk ratio.
        chi2_mh: Mantel-Haenszel chi-square statistic.
        p_value: P-value belonging to ``chi2_mh``.
        cochran_q: Cochran's Q heterogeneity statistic.
        q_p_value: P-value belonging to ``cochran_q``.
        i_squared: I-squared statistic, formatted as a percentage.
        tau_squared: Tau-squared statistic.
    """

    common_or: float
    common_rr: float
    common_rd: float
    or_ci: Tuple[float, float]
    rr_ci: Tuple[float, float]
    chi2_mh: float
    p_value: float
    cochran_q: float
    q_p_value: float
    i_squared: float
    tau_squared: float

    def __repr__(self) -> str:
        """Return a one-line description of the pooled OR and its interval."""
        lower = self.or_ci[0]
        upper = self.or_ci[1]
        return f"Mantel-Haenszel OR: {self.common_or:.3f} ({lower:.3f}-{upper:.3f})"

    def summary(self) -> str:
        """Return a multi-line text report of the analysis."""
        or_low, or_high = self.or_ci[0], self.or_ci[1]
        rr_low, rr_high = self.rr_ci[0], self.rr_ci[1]
        report_lines = [
            "Mantel-Haenszel Analysis:",
            f" Common OR: {self.common_or:.3f} (95% CI: {or_low:.3f}-{or_high:.3f})",
            f" Common RR: {self.common_rr:.3f} (95% CI: {rr_low:.3f}-{rr_high:.3f})",
            f" Common RD: {self.common_rd:.3f}",
            f" Test for heterogeneity: χ²={self.cochran_q:.3f}, p={self.q_p_value:.3f}",
            f" I² = {self.i_squared:.1f}%, τ² = {self.tau_squared:.3f}",
        ]
        return "\n".join(report_lines)
@dataclass
class DirectStandardizationResult:
    """Bundle the outputs of a direct standardization.

    Attributes:
        crude_rate: Unadjusted overall rate.
        adjusted_rate: Rate standardized to the standard population.
        standard_population: Standard population used as weights.
        stratum_specific_rates: Rates supplied for each stratum.
        variance: Variance of the adjusted rate.
        ci: (lower, upper) interval for the adjusted rate.
    """

    crude_rate: float
    adjusted_rate: float
    standard_population: np.ndarray
    stratum_specific_rates: np.ndarray
    variance: float
    ci: Tuple[float, float]

    def __repr__(self) -> str:
        """Return a one-line description of the adjusted rate and its interval."""
        lower = self.ci[0]
        upper = self.ci[1]
        return f"Directly Adjusted Rate: {self.adjusted_rate:.3f} ({lower:.3f}-{upper:.3f})"
def mantel_haenszel_or(
    stratified_tables: Union[StratifiedTable, List[Table2x2]],
    confidence: float = 0.95
) -> MantelHaenszelResult:
    """
    Calculate Mantel-Haenszel pooled odds ratio, risk ratio and risk difference.

    The risk-based measures treat cells ``(a, b)`` as one exposure group and
    ``(c, d)`` as the other, i.e. risks are ``a / (a + b)`` and
    ``c / (c + d)``.

    Args:
        stratified_tables: StratifiedTable or list of Table2x2 objects
        confidence: Confidence level for the OR and RR intervals

    Returns:
        MantelHaenszelResult object. Point estimates fall back to 0.0 and
        intervals to (0.0, 0.0) when the corresponding pooled sums make
        them undefined.

    Example:
        >>> table1 = Table2x2(10, 20, 30, 40)
        >>> table2 = Table2x2(15, 25, 35, 45)
        >>> result = mantel_haenszel_or([table1, table2])
    """
    if isinstance(stratified_tables, list):
        stratified_tables = StratifiedTable(stratified_tables)

    # Needed by both intervals, so compute it unconditionally.
    z = stats.norm.ppf(1 - (1 - confidence) / 2)

    # Odds ratio: R = sum(a*d/n), S = sum(b*c/n), plus the three
    # Robins-Breslow-Greenland variance sums.
    sum_r = 0.0
    sum_s = 0.0
    sum_pr = 0.0
    sum_ps_qr = 0.0
    sum_qs = 0.0

    # Risk ratio: numerator, denominator and Greenland-Robins variance sum.
    sum_num_rr = 0.0
    sum_den_rr = 0.0
    sum_var_rr = 0.0

    # Risk difference: numerator and Mantel-Haenszel weight sum.
    sum_num_rd = 0.0
    sum_den_rd = 0.0

    # Per-stratum odds ratios and inverse-variance weights for heterogeneity.
    stratum_or = []
    stratum_weights = []

    for table in stratified_tables.tables:
        a, b, c, d = table.a, table.b, table.c, table.d
        n = table.total
        if n <= 0:
            # An empty stratum contributes nothing and would divide by zero.
            continue

        n1 = a + b
        n0 = c + d

        r = a * d / n
        s = b * c / n
        p = (a + d) / n
        q = (b + c) / n
        sum_r += r
        sum_s += s
        sum_pr += p * r
        sum_ps_qr += p * s + q * r
        sum_qs += q * s

        sum_num_rr += a * n0 / n
        sum_den_rr += c * n1 / n
        sum_var_rr += (n1 * n0 * (a + c) - a * c * n) / n**2

        sum_num_rd += (a * n0 - c * n1) / n
        sum_den_rd += n1 * n0 / n

        # The log-OR and its inverse-variance weight need all four cells > 0.
        if min(a, b, c, d) > 0:
            stratum_or.append((a * d) / (b * c))
            stratum_weights.append(1 / (1/a + 1/b + 1/c + 1/d))

    # Pooled point estimates
    common_or = sum_r / sum_s if sum_s > 0 else 0.0
    common_rr = sum_num_rr / sum_den_rr if sum_den_rr > 0 else 0.0
    common_rd = sum_num_rd / sum_den_rd if sum_den_rd > 0 else 0.0

    # CI for OR (Robins, Breslow & Greenland variance of log OR)
    if sum_r > 0 and sum_s > 0:
        var_log_or = (
            sum_pr / (2 * sum_r**2)
            + sum_ps_qr / (2 * sum_r * sum_s)
            + sum_qs / (2 * sum_s**2)
        )
        se_log_or = np.sqrt(max(var_log_or, 0.0))
        log_or = np.log(common_or)
        or_ci_lower = float(np.exp(log_or - z * se_log_or))
        or_ci_upper = float(np.exp(log_or + z * se_log_or))
    else:
        or_ci_lower, or_ci_upper = 0.0, 0.0

    # CI for RR (Greenland & Robins variance of log RR)
    if sum_num_rr > 0 and sum_den_rr > 0:
        var_log_rr = sum_var_rr / (sum_num_rr * sum_den_rr)
        se_log_rr = np.sqrt(max(var_log_rr, 0.0))
        log_rr = np.log(common_rr)
        rr_ci_lower = float(np.exp(log_rr - z * se_log_rr))
        rr_ci_upper = float(np.exp(log_rr + z * se_log_rr))
    else:
        rr_ci_lower, rr_ci_upper = 0.0, 0.0

    # Test for heterogeneity (Cochran's Q)
    cochran_q = 0.0
    q_p_value = 1.0
    i_squared = 0.0
    tau_squared = 0.0
    if len(stratum_or) > 1:
        cochran_q = float(_cochran_q_test(stratum_or, stratum_weights))
        df = len(stratum_or) - 1
        q_p_value = float(stats.chi2.sf(cochran_q, df))

        if cochran_q > df:
            excess = cochran_q - df
            # I² as a percentage of Q
            i_squared = max(0.0, excess / cochran_q * 100)

            # Tau²: excess heterogeneity scaled by a function of the weights
            weight_total = sum(stratum_weights)
            scale = weight_total - sum(w**2 for w in stratum_weights) / weight_total
            if scale > 0:
                tau_squared = max(0.0, excess / scale)

    # Mantel-Haenszel chi-square test (1 degree of freedom)
    chi2_mh = _mantel_haenszel_chi2(stratified_tables)
    mh_p_value = float(stats.chi2.sf(chi2_mh, 1)) if chi2_mh > 0 else 1.0

    return MantelHaenszelResult(
        common_or=common_or,
        common_rr=common_rr,
        common_rd=common_rd,
        or_ci=(or_ci_lower, or_ci_upper),
        rr_ci=(rr_ci_lower, rr_ci_upper),
        chi2_mh=chi2_mh,
        p_value=mh_p_value,
        cochran_q=cochran_q,
        q_p_value=q_p_value,
        i_squared=i_squared,
        tau_squared=tau_squared
    )
def _cochran_q_test(stratum_or: List[float], weights: List[float]) -> float:
    """Return Cochran's Q for the given stratum odds ratios and weights.

    Q is the weighted sum of squared deviations of each log odds ratio
    from the weighted mean log odds ratio. Returns 0.0 when fewer than
    two odds ratios are supplied.
    """
    if len(stratum_or) <= 1:
        return 0.0

    log_ors = [np.log(value) for value in stratum_or]

    # Weighted mean of the log odds ratios
    weighted_total = sum(w * log_val for w, log_val in zip(weights, log_ors))
    pooled_log_or = weighted_total / sum(weights)

    return sum(
        w * (log_val - pooled_log_or)**2
        for w, log_val in zip(weights, log_ors)
    )


def _mantel_haenszel_chi2(stratified_tables: StratifiedTable) -> float:
    """Return the Mantel-Haenszel chi-square statistic across all strata.

    The statistic is the squared sum of (observed ``a`` - expected ``a``)
    divided by the summed variance of ``a``; no continuity correction is
    applied. Returns 0.0 when the summed variance is not positive.
    """
    deviation_total = 0.0
    variance_total = 0.0

    for tbl in stratified_tables.tables:
        a, b, c, d = tbl.a, tbl.b, tbl.c, tbl.d
        n = tbl.total
        row_1, row_2 = a + b, c + d
        col_1, col_2 = a + c, b + d

        # Expected value of a under the null; taken as 0 for an empty stratum
        if n > 0:
            expected_a = row_1 * col_1 / n
        else:
            expected_a = 0
        deviation_total += a - expected_a

        # The variance term is only defined for n > 1
        if n > 1:
            variance_total += row_1 * row_2 * col_1 * col_2 / (n**2 * (n - 1))

    if variance_total > 0:
        return deviation_total**2 / variance_total
    return 0.0
def test_effect_modification(
    stratified_tables: StratifiedTable,
    method: str = "breslow_day"
) -> Dict[str, float]:
    """
    Test for effect modification (interaction) across strata.

    Args:
        stratified_tables: StratifiedTable object
        method: 'breslow_day' or 'woolf'

    Returns:
        Dictionary with keys "test", "statistic", "p_value" and "df".
        With fewer than two strata no test is possible and a null result
        (statistic 0.0, p-value 1.0, df 0) is returned.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # NOTE(review): the ``test_`` prefix may make pytest collect this function
    # when it is imported into a test module; the name is kept because it is
    # part of the public API.
    if method not in ("breslow_day", "woolf"):
        # Reject typos instead of silently running the default test.
        raise ValueError(
            f"Unknown method {method!r}; expected 'breslow_day' or 'woolf'"
        )

    if len(stratified_tables) < 2:
        # Same key set as the computed results; both methods currently
        # resolve to Woolf's statistic in this module.
        return {"test": "woolf", "statistic": 0.0, "p_value": 1.0, "df": 0}

    if method == "woolf":
        return _woolf_test(stratified_tables)
    return _breslow_day_test(stratified_tables)
def _woolf_test(stratified_tables: StratifiedTable) -> Dict[str, float]:
    """Woolf's test for homogeneity of odds ratios.

    Strata with any zero cell are skipped, because neither the log odds
    ratio nor its inverse-variance weight is defined for them.

    Returns:
        Dictionary with keys "test", "statistic", "p_value" and "df".
        When fewer than two strata are usable, a null result
        (statistic 0.0, p-value 1.0, df 0) is returned.
    """
    log_ors = []
    weights = []

    for table in stratified_tables.tables:
        a, b, c, d = table.a, table.b, table.c, table.d
        if min(a, b, c, d) <= 0:
            continue
        log_ors.append(np.log((a * d) / (b * c)))
        # Inverse variance weight of the log odds ratio
        weights.append(1 / (1/a + 1/b + 1/c + 1/d))

    if len(log_ors) < 2:
        return {"test": "woolf", "statistic": 0.0, "p_value": 1.0, "df": 0}

    # Weighted mean of log OR
    weighted_mean = sum(w * val for w, val in zip(weights, log_ors)) / sum(weights)

    # Woolf's chi-square: weighted squared deviations from the pooled log OR
    chi2 = float(sum(w * (val - weighted_mean)**2 for w, val in zip(weights, log_ors)))
    df = len(log_ors) - 1

    # Survival function avoids the cancellation in 1 - cdf for small p-values
    p_value = float(stats.chi2.sf(chi2, df))

    return {
        "test": "woolf",
        "statistic": chi2,
        "p_value": p_value,
        "df": df
    }


def _breslow_day_test(stratified_tables: StratifiedTable) -> Dict[str, float]:
    """Approximate the Breslow-Day homogeneity test with Woolf's test.

    No Breslow-Day statistic is computed here: the call is forwarded to
    ``_woolf_test`` and the returned dictionary is labelled "woolf".
    """
    return _woolf_test(stratified_tables)
def direct_standardization(
    stratum_rates: np.ndarray,
    stratum_populations: np.ndarray,
    standard_population: np.ndarray,
    confidence: float = 0.95
) -> DirectStandardizationResult:
    """
    Perform direct standardization of rates.

    Args:
        stratum_rates: Rates in each stratum (array or sequence)
        stratum_populations: Population in each stratum (array or sequence)
        standard_population: Standard population distribution (array or sequence)
        confidence: Confidence level

    Returns:
        DirectStandardizationResult object

    Raises:
        ValueError: If the inputs differ in length, if any stratum
            population is not positive, or if the standard population
            does not sum to a positive value.
    """
    # Validate inputs
    n_strata = len(stratum_rates)
    if len(stratum_populations) != n_strata or len(standard_population) != n_strata:
        raise ValueError("All inputs must have same length")

    # Float arrays so plain sequences work in the element-wise arithmetic below
    rates = np.asarray(stratum_rates, dtype=float)
    populations = np.asarray(stratum_populations, dtype=float)
    weights = np.asarray(standard_population, dtype=float)

    if np.any(populations <= 0):
        # Each stratum population is a divisor in the variance formula
        raise ValueError("All stratum populations must be positive")
    weight_total = np.sum(weights)
    if weight_total <= 0:
        raise ValueError("Standard population must sum to a positive value")

    # Crude rate: population-weighted mean of the stratum rates
    crude_rate = float(np.sum(rates * populations) / np.sum(populations))

    # Directly adjusted rate: standard-population-weighted mean
    adjusted_rate = float(np.sum(rates * weights) / weight_total)

    # Variance of adjusted rate (binomial variance within each stratum)
    variance = float(
        np.sum((weights**2 * rates * (1 - rates)) / populations) / weight_total**2
    )

    # Normal-approximation confidence interval
    z = stats.norm.ppf(1 - (1 - confidence) / 2)
    se = np.sqrt(variance)
    ci_lower = float(adjusted_rate - z * se)
    ci_upper = float(adjusted_rate + z * se)

    return DirectStandardizationResult(
        crude_rate=crude_rate,
        adjusted_rate=adjusted_rate,
        # asarray hands ndarrays through unchanged and wraps sequences
        standard_population=np.asarray(standard_population),
        stratum_specific_rates=np.asarray(stratum_rates),
        variance=variance,
        ci=(ci_lower, ci_upper)
    )
def indirect_standardization(
    observed_cases: np.ndarray,
    stratum_populations: np.ndarray,
    reference_rates: np.ndarray,
    confidence: float = 0.95
) -> Dict[str, float]:
    """
    Perform indirect standardization (SMR calculation).

    Args:
        observed_cases: Observed cases in each stratum (array or sequence)
        stratum_populations: Population in each stratum (array or sequence)
        reference_rates: Reference rates in each stratum (array or sequence)
        confidence: Confidence level

    Returns:
        Dictionary with keys "smr", "observed_cases", "expected_cases",
        "ci_lower", "ci_upper" and "confidence". When the expected count
        is not positive the SMR is reported as 0.0 and both interval
        bounds are NaN.

    Raises:
        ValueError: If ``stratum_populations`` and ``reference_rates``
            differ in shape.
    """
    # Float arrays so plain sequences work in the element-wise product below
    observed = np.asarray(observed_cases, dtype=float)
    populations = np.asarray(stratum_populations, dtype=float)
    rates = np.asarray(reference_rates, dtype=float)
    if populations.shape != rates.shape:
        raise ValueError("stratum_populations and reference_rates must have same length")

    # Expected cases using reference rates
    expected_cases = float(np.sum(populations * rates))
    total_observed = float(np.sum(observed))

    if expected_cases > 0:
        # Standardized Mortality/Morbidity Ratio
        smr = total_observed / expected_cases
        alpha = 1 - confidence

        if total_observed >= 10:
            # Byar's approximation to the Poisson limits
            z = stats.norm.ppf(1 - alpha / 2)
            upper_count = total_observed + 1
            ci_lower = (
                total_observed
                * (1 - 1 / (9 * total_observed) - z / (3 * np.sqrt(total_observed)))**3
                / expected_cases
            )
            ci_upper = (
                upper_count
                * (1 - 1 / (9 * upper_count) + z / (3 * np.sqrt(upper_count)))**3
                / expected_cases
            )
        else:
            # Exact Poisson CI via the chi-square distribution. With zero
            # observed cases the lower limit is 0; chi2.ppf with 0 degrees
            # of freedom would return NaN.
            if total_observed > 0:
                ci_lower = stats.chi2.ppf(alpha / 2, 2 * total_observed) / (2 * expected_cases)
            else:
                ci_lower = 0.0
            ci_upper = stats.chi2.ppf(1 - alpha / 2, 2 * (total_observed + 1)) / (2 * expected_cases)
    else:
        # No expected cases: the ratio and its interval are undefined.
        smr = 0.0
        ci_lower = float("nan")
        ci_upper = float("nan")

    return {
        "smr": float(smr),
        "observed_cases": total_observed,
        "expected_cases": expected_cases,
        "ci_lower": float(ci_lower),
        "ci_upper": float(ci_upper),
        "confidence": confidence
    }
def stratified_by_variable(
    data,
    exposure_var: str,
    outcome_var: str,
    stratify_var: str
) -> StratifiedTable:
    """
    Create stratified tables from DataFrame.

    Rows are grouped by ``stratify_var`` and one 2x2 table is built per
    group. Only rows whose exposure and outcome values both equal 0 or 1
    are counted; any other value falls outside all four cells.

    Args:
        data: pandas DataFrame (any object exposing ``groupby`` is accepted)
        exposure_var: Exposure variable name
        outcome_var: Outcome variable name
        stratify_var: Variable to stratify by

    Returns:
        StratifiedTable object with one table per stratum, labelled with
        ``str(stratum)``.

    Raises:
        ValueError: If ``data`` has no ``groupby`` attribute.
    """
    if not hasattr(data, 'groupby'):
        raise ValueError("Data must be a pandas DataFrame or similar")

    tables = []
    strata_names = []

    for stratum, group in data.groupby(stratify_var):
        # Build each boolean mask once and combine them per cell.
        exposed = group[exposure_var] == 1
        unexposed = group[exposure_var] == 0
        with_outcome = group[outcome_var] == 1
        without_outcome = group[outcome_var] == 0

        # NOTE(review): here a/b share the outcome (b = unexposed with the
        # outcome), whereas mantel_haenszel_or computes risks as a/(a+b) and
        # c/(c+d), i.e. treats a/b as sharing the exposure. Confirm the cell
        # layout against Table2x2 before relying on pooled RR/RD.
        table = Table2x2(
            a=int((exposed & with_outcome).sum()),
            b=int((unexposed & with_outcome).sum()),
            c=int((exposed & without_outcome).sum()),
            d=int((unexposed & without_outcome).sum())
        )
        tables.append(table)
        strata_names.append(str(stratum))

    return StratifiedTable(tables, strata_names)
# MODULE EXPORTS
# Names made available by ``from <module> import *``.
__all__ = [
    # Enumerations and result containers
    'StratifiedMethod',
    'StratifiedTable',
    'MantelHaenszelResult',
    'DirectStandardizationResult',
    # Analysis functions
    'mantel_haenszel_or',
    'test_effect_modification',
    'direct_standardization',
    'indirect_standardization',
    'stratified_by_variable',
]