Source code for welltest_pta.validation.cross_validation

r"""
welltest_pta.validation.cross_validation
========================================
Cross-validation scores for the auto event-detector.

Three complementary scores are computed and printed when the user calls
``WellTest.from_file(..., cross_validate=True)`` or directly via
:func:`cross_validate_detector`:

1. **Bootstrap event-count stability**

   Re-run the detector on :math:`K` random resamples of the data
   (each downsampled to a fraction :math:`f` of the original gauge rate)
   and report the standard deviation of the number of detected DDs and
   BUs. Low σ ⇒ stable detection.

2. **Parameter sensitivity sweep**

   Vary each detector parameter (``hampel_sigma``, ``spike_percentile``,
   ``min_pta_dp_psi``, ``tail_trim_dev_n_sigma``) by ±20 % and report how
   many events change. Few changes ⇒ robust to parameter choice.

3. **Edge-position consistency**

   Re-run the detector on a second set of bootstrap replicas and compare
   each replica's drawdown/buildup mask with the reference detection.
   Report the mean ± std "Jaccard overlap" of the masked time axis.

The combined score is a 0–100 confidence index:

==========  =============================================
80 – 100    Highly robust — manual review optional
60 –  80    Reasonable — spot-check critical events
40 –  60    Marginal — recommend manual splitting
 0 –  40    Unstable — manual splitting strongly advised
==========  =============================================
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import pandas as pd

from welltest_pta.detection.detector import (
    EventDetector,
    EventDetectorConfig,
)

logger = logging.getLogger(__name__)


# ─────────────────────────────────────────────────────────────────────────────
# Result container
# ─────────────────────────────────────────────────────────────────────────────

@dataclass
class DetectorCVResult:
    """Cross-validation report from :func:`cross_validate_detector`.

    The fields are grouped as: bootstrap event-count statistics,
    per-parameter sensitivity deltas, Jaccard edge-consistency statistics,
    the combined 0–100 score with its textual grade, and the event counts
    of the reference detection the replicas are compared against.
    """

    # Bootstrap
    n_bootstrap: int
    bootstrap_n_dd_mean: float
    bootstrap_n_dd_std: float
    bootstrap_n_bu_mean: float
    bootstrap_n_bu_std: float
    # Sensitivity (per-parameter): {param: {"delta_n_dd": int, "delta_n_bu": int}}
    sensitivity: dict[str, dict[str, int]]
    # Edge consistency
    edge_jaccard_mean: float
    edge_jaccard_std: float
    # Combined score 0–100
    overall_score: float
    grade: str
    # Reference detection (for comparison)
    ref_n_dd: int
    ref_n_bu: int
    # Free-form extras (excluded from repr to keep it readable)
    raw_data: dict = field(default_factory=dict, repr=False)

    def print_report(self) -> None:
        """Pretty-print the CV report to stdout."""
        sep = "═" * 72
        print(f"\n{sep}")
        print(" EVENT-DETECTOR CROSS-VALIDATION REPORT")
        print(sep)
        print(f" Reference detection: {self.ref_n_dd} drawdowns, {self.ref_n_bu} buildups")
        print()
        print(f" Bootstrap stability (K = {self.n_bootstrap} replicas):")
        print(f" n_drawdowns: {self.bootstrap_n_dd_mean:6.2f} ± {self.bootstrap_n_dd_std:.2f}")
        print(f" n_buildups: {self.bootstrap_n_bu_mean:6.2f} ± {self.bootstrap_n_bu_std:.2f}")
        print()
        print(" Edge-position consistency (Jaccard overlap):")
        print(f" mean = {self.edge_jaccard_mean:.3f} (1.0 = perfect)")
        print(f" std = {self.edge_jaccard_std:.3f}")
        print()
        print(" Parameter sensitivity (Δ events under ±20 % perturbation):")
        for pname, info in self.sensitivity.items():
            # Deltas are event counts; coerce so that integral floats
            # (e.g. 1.0) do not break the ``+d`` integer format spec.
            delta_dd = int(info["delta_n_dd"])
            delta_bu = int(info["delta_n_bu"])
            print(f" {pname:<22s} Δ_dd = {delta_dd:+d}, "
                  f"Δ_bu = {delta_bu:+d}")
        print()
        print(f" ─ OVERALL CV SCORE: {self.overall_score:5.1f} / 100 "
              f"({self.grade})")
        print(sep + "\n")
# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

# Event labels that count as pressure-transient-analysis (PTA) periods.
_PTA_LABELS = ("drawdown", "buildup")


def _count_pta_events(
    df: pd.DataFrame,
) -> tuple[int, int, list[tuple[int, int, str]]]:
    """Count DD / BU events in an annotated DataFrame; also return RLE list.

    Consecutive rows sharing the same ``event`` label are merged into one
    run. Returns ``(n_drawdowns, n_buildups, runs)`` where ``runs`` holds one
    ``(start, end, label)`` tuple per drawdown/buildup run, with ``end``
    exclusive and both bounds positional. A frame without an ``event``
    column, or with no rows, yields ``(0, 0, [])``.
    """
    if "event" not in df.columns:
        return 0, 0, []
    labels = df["event"].to_numpy(dtype=object)
    n_rows = len(labels)
    if n_rows == 0:
        return 0, 0, []

    # A run starts wherever a label differs from its predecessor. Doing this
    # on the array avoids one ``Series.iloc`` call per sample.
    breaks = np.flatnonzero(labels[1:] != labels[:-1]) + 1
    starts = np.concatenate(([0], breaks))
    ends = np.concatenate((breaks, [n_rows]))

    pta_groups = [
        (int(start), int(end), labels[start])
        for start, end in zip(starts, ends)
        if labels[start] in _PTA_LABELS
    ]
    n_dd = sum(1 for _, _, label in pta_groups if label == "drawdown")
    n_bu = sum(1 for _, _, label in pta_groups if label == "buildup")
    return n_dd, n_bu, pta_groups


def _jaccard_pta_mask(
    ref_df: pd.DataFrame, alt_df: pd.DataFrame
) -> float:
    """
    Jaccard overlap of the "is-PTA" masks aligned on the timestamp index.

    Both DataFrames must have a ``timestamp`` column and an ``event`` column.
    Resamples to the reference timeline by nearest-neighbour.

    Returns 0.0 when either frame lacks an ``event`` column or when neither
    mask contains any drawdown/buildup sample. A missing ``timestamp``
    column raises ``KeyError``.
    """
    if "event" not in ref_df.columns or "event" not in alt_df.columns:
        return 0.0
    ref_pta = ref_df["event"].isin(_PTA_LABELS).to_numpy()

    # Reindex alt_df labels to the reference timestamps via nearest-neighbour
    alt_ts = pd.Series(alt_df["event"].values, index=alt_df["timestamp"])
    # ``reindex`` refuses a source index with duplicate labels, so keep only
    # the first label seen for any repeated timestamp.
    alt_ts = alt_ts[~alt_ts.index.duplicated(keep="first")]
    alt_ts = alt_ts.sort_index()
    alt_aligned = alt_ts.reindex(ref_df["timestamp"], method="nearest")
    alt_pta = alt_aligned.isin(_PTA_LABELS).to_numpy()

    inter = np.logical_and(ref_pta, alt_pta).sum()
    union = np.logical_or(ref_pta, alt_pta).sum()
    return float(inter / union) if union > 0 else 0.0


# ─────────────────────────────────────────────────────────────────────────────
# Public utilities
# ─────────────────────────────────────────────────────────────────────────────
def bootstrap_score(
    df: pd.DataFrame,
    cfg: Optional[EventDetectorConfig] = None,
    n_bootstrap: int = 10,
    downsample_frac: float = 0.85,
    rng: Optional[np.random.Generator] = None,
) -> dict[str, float]:
    """Down-sample bootstrap of detector event counts.

    Each replica keeps a random, order-preserving subset of the rows of
    ``df`` (drawn without replacement), re-runs the detector on it and
    counts the drawdown / buildup events found.

    Parameters
    ----------
    df
        Frame passed to ``EventDetector.detect``.
    cfg
        Detector configuration; a default ``EventDetectorConfig()`` is used
        when omitted.
    n_bootstrap
        Number of replicas to attempt.
    downsample_frac
        Fraction of rows retained per replica. At least 100 rows are kept,
        capped at ``len(df)``, so a frame with 100 rows or fewer is
        re-detected in full for every replica.
    rng
        Random generator; defaults to ``np.random.default_rng(42)``.

    Returns
    -------
    dict
        ``n_dd_mean`` / ``n_dd_std`` / ``n_bu_mean`` / ``n_bu_std`` over the
        replicas that succeeded, plus ``n_replicas_ok``. When no replica
        succeeds the means are 0.0 and the stds are NaN.
    """
    if rng is None:
        rng = np.random.default_rng(42)

    no_replicas = {
        "n_dd_mean": 0.0,
        "n_dd_std": np.nan,
        "n_bu_mean": 0.0,
        "n_bu_std": np.nan,
        "n_replicas_ok": 0,
    }

    n = len(df)
    if n == 0:
        # Nothing to resample.
        return no_replicas

    # Sampling is without replacement, so the replica can never be larger
    # than the frame itself; without the cap the 100-row floor makes
    # ``rng.choice`` raise on short frames.
    keep_n = min(n, max(int(n * downsample_frac), 100))

    n_dds: list[int] = []
    n_bus: list[int] = []
    for _ in range(n_bootstrap):
        idx = np.sort(rng.choice(n, size=keep_n, replace=False))
        sub = df.iloc[idx].reset_index(drop=True)
        try:
            det = EventDetector(cfg=cfg or EventDetectorConfig())
            ann = det.detect(sub)
            n_dd, n_bu, _ = _count_pta_events(ann)
        except Exception as e:  # pragma: no cover
            # Best effort: a failed replica is skipped, not fatal.
            logger.debug("Bootstrap replica failed: %s", e)
            continue
        n_dds.append(n_dd)
        n_bus.append(n_bu)

    if not n_dds:
        return no_replicas
    return {
        "n_dd_mean": float(np.mean(n_dds)),
        "n_dd_std": float(np.std(n_dds)),
        "n_bu_mean": float(np.mean(n_bus)),
        "n_bu_std": float(np.std(n_bus)),
        "n_replicas_ok": len(n_dds),
    }
def parameter_sensitivity(
    df: pd.DataFrame,
    cfg: Optional[EventDetectorConfig] = None,
    perturbation: float = 0.20,
) -> dict[str, dict[str, int]]:
    """Sweep ±perturbation around each detector parameter.

    The detector is first run with the base configuration. Then, for each of
    ``hampel_sigma``, ``spike_percentile``, ``min_pta_dp_psi`` and
    ``tail_trim_dev_n_sigma``, it is re-run with that single parameter scaled
    by ``1 - perturbation`` and by ``1 + perturbation``.

    Parameters
    ----------
    df
        Frame passed to ``EventDetector.detect``.
    cfg
        Base detector configuration; a default ``EventDetectorConfig()`` is
        used when omitted.
    perturbation
        Fractional change applied to each parameter (0.20 = ±20 %).

    Returns
    -------
    dict
        ``{param: {"delta_n_dd": int, "delta_n_bu": int}}`` holding, per
        parameter, the signed change in drawdown / buildup count with the
        largest magnitude over the two perturbations. A perturbation whose
        detector run fails contributes nothing, so a parameter for which
        both runs fail reports zero deltas.
    """
    # Local import: the module itself only imports ``dataclass`` and ``field``.
    from dataclasses import replace

    base_cfg = cfg or EventDetectorConfig()
    base_ann = EventDetector(cfg=base_cfg).detect(df)
    base_n_dd, base_n_bu, _ = _count_pta_events(base_ann)

    params_to_perturb = {
        "hampel_sigma": base_cfg.hampel_sigma,
        "spike_percentile": base_cfg.spike_percentile,
        "min_pta_dp_psi": base_cfg.min_pta_dp_psi,
        "tail_trim_dev_n_sigma": base_cfg.tail_trim_dev_n_sigma,
    }

    out: dict[str, dict[str, int]] = {}
    for name, val in params_to_perturb.items():
        max_dd_diff = 0
        max_bu_diff = 0
        for sign in (-1, +1):
            new_val = val * (1.0 + sign * perturbation)
            try:
                # ``replace`` builds the copy through the dataclass
                # constructor, so it copes with frozen configs and with
                # fields that are not constructor arguments.
                cfg_perturbed = replace(base_cfg, **{name: new_val})
                ann = EventDetector(cfg=cfg_perturbed).detect(df)
                n_dd, n_bu, _ = _count_pta_events(ann)
            except Exception as e:  # pragma: no cover
                # Best effort: a failed perturbation is skipped, not fatal.
                logger.debug("Sensitivity perturbation failed for %s: %s", name, e)
                continue
            # Keep the signed delta with the largest magnitude.
            if abs(n_dd - base_n_dd) > abs(max_dd_diff):
                max_dd_diff = n_dd - base_n_dd
            if abs(n_bu - base_n_bu) > abs(max_bu_diff):
                max_bu_diff = n_bu - base_n_bu
        out[name] = {"delta_n_dd": int(max_dd_diff),
                     "delta_n_bu": int(max_bu_diff)}
    return out
# ───────────────────────────────────────────────────────────────────────────── # Top-level function # ─────────────────────────────────────────────────────────────────────────────
def cross_validate_detector(
    df: pd.DataFrame,
    cfg: Optional[EventDetectorConfig] = None,
    n_bootstrap: int = 8,
    downsample_frac: float = 0.85,
    perturbation: float = 0.20,
    seed: int = 42,
    print_report: bool = True,
) -> DetectorCVResult:
    """
    Run all three CV checks and return a :class:`DetectorCVResult`.

    The overall score is ``100 * (0.40 * bootstrap + 0.40 * jaccard +
    0.20 * sensitivity)``, clipped to 0–100, where each component lies in
    [0, 1].

    Parameters
    ----------
    df
        Parsed gauge DataFrame (output of :func:`welltest_pta.parser.parse`).
    cfg
        Detector configuration. Defaults to standard V8.1 settings.
    n_bootstrap
        Number of bootstrap replicas (≥4 recommended).
    downsample_frac
        Fraction of points retained per replica.
    perturbation
        Fractional change applied to each parameter (e.g. 0.20 = ±20 %).
    seed
        RNG seed for reproducibility.
    print_report
        If True (default) prints the human-readable report.
    """
    rng = np.random.default_rng(seed)

    # Reference detection
    base_cfg = cfg or EventDetectorConfig()
    ref_det = EventDetector(cfg=base_cfg)
    ref_ann = ref_det.detect(df)
    ref_n_dd, ref_n_bu, _ = _count_pta_events(ref_ann)

    # 1. Bootstrap
    bs = bootstrap_score(df, base_cfg, n_bootstrap, downsample_frac, rng)

    # 2. Sensitivity
    sens = parameter_sensitivity(df, base_cfg, perturbation)

    # 3. Edge consistency (re-run on bootstrap, compute Jaccard against reference)
    jaccards: list[float] = []
    n = len(df)
    # Sampling is without replacement, so cap the replica size at the frame
    # length; otherwise the 100-row floor makes ``rng.choice`` raise on
    # short frames.
    keep_n = min(n, max(int(n * downsample_frac), 100))
    n_replicas = n_bootstrap if n > 0 else 0  # nothing to resample when empty
    for _ in range(n_replicas):
        idx = np.sort(rng.choice(n, size=keep_n, replace=False))
        sub = df.iloc[idx].reset_index(drop=True)
        try:
            det = EventDetector(cfg=base_cfg)
            ann = det.detect(sub)
            jaccards.append(_jaccard_pta_mask(ref_ann, ann))
        except Exception as e:  # pragma: no cover
            # Best effort: skip the replica, but leave a trace for debugging.
            logger.debug("Edge-consistency replica failed: %s", e)
    j_mean = float(np.mean(jaccards)) if jaccards else 0.0
    j_std = float(np.std(jaccards)) if jaccards else 0.0

    # ── Composite score ──
    n_ref_events = ref_n_dd + ref_n_bu
    bs_pen = bs["n_dd_std"] + bs["n_bu_std"]
    # If every bootstrap replica failed the stds are NaN; ``max(0.0, nan)``
    # evaluates to 0.0, so the bootstrap component then scores zero.
    bs_score = max(0.0, 1.0 - bs_pen / max(n_ref_events, 1))
    sens_total = sum(
        abs(v["delta_n_dd"]) + abs(v["delta_n_bu"]) for v in sens.values()
    )
    sens_score = max(
        0.0,
        1.0 - sens_total / max(2 * n_ref_events * len(sens), 1),
    )
    overall = 100.0 * (0.40 * bs_score + 0.40 * j_mean + 0.20 * sens_score)
    overall = float(np.clip(overall, 0.0, 100.0))

    if overall >= 80:
        grade = "HIGHLY ROBUST"
    elif overall >= 60:
        grade = "REASONABLE"
    elif overall >= 40:
        grade = "MARGINAL — manual review recommended"
    else:
        grade = "UNSTABLE — manual splitting strongly advised"

    result = DetectorCVResult(
        n_bootstrap=n_bootstrap,
        bootstrap_n_dd_mean=bs["n_dd_mean"],
        bootstrap_n_dd_std=bs["n_dd_std"],
        bootstrap_n_bu_mean=bs["n_bu_mean"],
        bootstrap_n_bu_std=bs["n_bu_std"],
        sensitivity=sens,
        edge_jaccard_mean=j_mean,
        edge_jaccard_std=j_std,
        overall_score=overall,
        grade=grade,
        ref_n_dd=ref_n_dd,
        ref_n_bu=ref_n_bu,
        raw_data={"jaccards": jaccards},
    )
    if print_report:
        result.print_report()
    return result