#!/usr/bin/env python3
"""
ensemble_selector_v1.py — Regime-Aware DEX LP Ensemble Selector

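Combines 4 LP modes, choosing one per rebalance period from the detected regime:
  range_order          → bull trend with active flow (drift > +3%, drift/path > 0.04, ≥ 2 swaps/h)
  downtrend_harvester  → downtrend (drift < -3% or drift/path < -0.05)
  vol_gated            → neutral market where fee_rate > vol_k*σ² (σ = window volatility)
  idle                 → thin flow (< 1 swap/h or fee budget < 0.05%/day) or no rule matched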

Self-contained: no external imports from cl_fee_replay_fast_npz_v3.
"""
from __future__ import annotations

import argparse
import json
import math
from pathlib import Path
from typing import Dict, List, Any

import numpy as np
import pandas as pd

# Version tag printed by main() at start-up so reports can be traced to this script.
SCRIPT_VERSION = "ensemble_selector_v1_2026_05_04"


# ── Math helpers ──────────────────────────────────────────────────────────────

def sqrt_raw(p: float, d0: int = 6, d1: int = 18) -> float:
    """Convert a human-readable price into the raw square-root price for CL math.

    Computes ``sqrt(10 ** (d1 - d0) / p)``, i.e. the square root of the raw
    token1-per-token0 ratio when ``p`` is the price of token1 quoted in token0.
    ``p`` is floored at 1e-300 so zero or negative prices never divide by zero.
    """
    decimal_scale = 10 ** (d1 - d0)
    safe_price = max(p, 1e-300)
    return math.sqrt(decimal_scale / safe_price)


def liquidity_for_capital(cap: float, p0: float, lo: float, up: float,
                          d0: int = 6, d1: int = 18) -> float:
    """Return the liquidity ``L`` that ``cap`` buys in the price range [lo, up] at p0.

    The value of one unit of liquidity is computed in the same units as
    ``position_value`` (token0 amount plus token1 amount valued at p0), and
    ``cap`` is divided by it. Returns 0.0 when that unit value is not above
    1e-300 (degenerate or inverted range).
    """
    # Higher price -> smaller raw sqrt price, so `up` maps to the low sqrt bound.
    s_cur, s_hi, s_lo = (sqrt_raw(x, d0, d1) for x in (p0, up, lo))
    scale0 = 10 ** d0
    scale1 = 10 ** d1

    if p0 <= lo:
        # At/below the lower bound the position holds only token1.
        unit_value = (s_lo - s_hi) / scale1 * p0
    elif p0 >= up:
        # At/above the upper bound the position holds only token0.
        unit_value = (s_lo - s_hi) / (s_hi * s_lo) / scale0
    else:
        token0_value = (s_lo - s_cur) / (s_cur * s_lo) / scale0
        token1_value = (s_cur - s_hi) / scale1 * p0
        unit_value = token0_value + token1_value

    if unit_value > 1e-300:
        return cap / unit_value
    return 0.0


def position_value(L: float, p: float, lo: float, up: float,
                   d0: int = 6, d1: int = 18) -> float:
    """Value a concentrated-liquidity position of size ``L`` on [lo, up] at price p.

    Token amounts follow the standard CL formulas on raw sqrt prices. The
    result is ``amount0 / 10**d0 + amount1 / 10**d1 * p``, i.e. expressed in
    token0 units with token1 converted at ``p``.
    """
    # Higher price -> smaller raw sqrt price, so `up` maps to the low sqrt bound.
    s_cur, s_hi, s_lo = (sqrt_raw(x, d0, d1) for x in (p, up, lo))

    if p <= lo:
        # Price at/below the range: everything sits in token1.
        amt0 = 0.0
        amt1 = L * (s_lo - s_hi)
    elif p >= up:
        # Price at/above the range: everything sits in token0.
        amt0 = L * (s_lo - s_hi) / (s_hi * s_lo)
        amt1 = 0.0
    else:
        amt0 = L * (s_lo - s_cur) / (s_cur * s_lo)
        amt1 = L * (s_cur - s_hi)

    usd0 = amt0 / (10 ** d0)
    usd1 = amt1 / (10 ** d1) * p
    return usd0 + usd1


def max_drawdown(eq: np.ndarray) -> float:
    """Return the deepest peak-to-trough decline of an equity curve, in percent (<= 0).

    Non-positive (and NaN) samples are replaced by 1e-300 when tracking the
    running peak so the division never hits zero; NaN drawdowns are ignored.
    """
    positive = np.where(eq > 0, eq, 1e-300)
    running_peak = np.maximum.accumulate(positive)
    drawdowns = eq / running_peak - 1
    return float(np.nanmin(drawdowns) * 100)


# ── Regime feature computation ────────────────────────────────────────────────

def window_features(price: np.ndarray, input_usd: np.ndarray,
                    ts: np.ndarray, end_idx: int,
                    lookback_hours: float) -> Dict[str, float]:
    """Summarise the trailing ``lookback_hours`` of samples ending at ``end_idx``.

    The window covers every sample whose timestamp is >= ts[end_idx] minus the
    lookback, up to and including ``end_idx``. It is located with
    ``np.searchsorted``, so ``ts`` must be sorted ascending.

    Returned keys (percent units, i.e. 1.0 == 1 %):
        drift_pct          net price change over the window
        path_pct           sum of |log returns| * 100, floored at 1e-9
        vol_pct            std(log returns) * sqrt(#returns) * 100
        move_p95_pct       95th percentile |log return| * 100, floored at 0.01
        flow_usd_per_hour  summed input_usd / window hours
        events_per_hour    sample count / window hours
        toxicity           |drift| / path, capped at 1

    Window hours are floored at one minute. With fewer than three samples a
    neutral default dict is returned (events_per_hour == 0).
    """
    anchor_ts = int(ts[end_idx])
    first = int(np.searchsorted(ts, anchor_ts - int(lookback_hours * 3600), side="left"))
    first = min(max(0, first), end_idx)
    stop = end_idx + 1

    prices = price[first:stop]
    if len(prices) < 3:
        return {
            "drift_pct": 0.0, "path_pct": 0.0, "vol_pct": 0.0,
            "move_p95_pct": 1.0, "flow_usd_per_hour": 0.0,
            "events_per_hour": 0.0, "toxicity": 0.0,
        }

    log_ret = np.diff(np.log(np.maximum(prices, 1e-300)))
    abs_ret = np.abs(log_ret)

    drift = float((prices[-1] / prices[0] - 1.0) * 100.0)
    path = float(np.sum(abs_ret) * 100.0)
    path_floor = max(path, 1e-9)
    vol = float(np.std(log_ret) * math.sqrt(max(1, len(log_ret))) * 100.0)
    # >= 3 samples guarantee >= 2 returns, so the percentile is always defined.
    p95 = float(np.nanpercentile(abs_ret * 100.0, 95))

    stamps = ts[first:stop]
    span_h = max(1.0 / 60.0, (int(stamps[-1]) - int(stamps[0])) / 3600.0)
    total_flow = np.sum(input_usd[first:stop])

    return {
        "drift_pct": drift,
        "path_pct": path_floor,
        "vol_pct": vol,
        "move_p95_pct": max(0.01, p95),
        "flow_usd_per_hour": float(total_flow / span_h),
        "events_per_hour": float(len(prices) / span_h),
        "toxicity": float(min(1.0, abs(drift) / path_floor)),
    }


# ── Regime selector ───────────────────────────────────────────────────────────

def select_regime(feat: Dict[str, float], fee_rate: float, capital: float,
                  vol_k: float = 2.0) -> str:
    """Pick the LP regime for the next holding period from window features.

    Rules are evaluated in order and the first match wins:
      1. idle                - fewer than 1 swap/hour, or the daily fee budget
                               (fee_rate * flow_usd_per_hour * 24 / capital,
                               in percent) is below 0.05 %.
      2. range_order         - drift > +3 %, drift/path > 0.04 and at least
                               2 swaps/hour.
      3. downtrend_harvester - drift < -3 % or drift/path < -0.05.
      4. vol_gated           - fee_rate > vol_k * (vol_pct / 100) ** 2.
      Otherwise idle.

    Args:
        feat: mapping with at least ``drift_pct``, ``path_pct``, ``vol_pct``,
            ``flow_usd_per_hour`` and ``events_per_hour`` (the keys produced
            by ``window_features``).
        fee_rate: pool fee as a fraction (e.g. 0.003).
        capital: capital used to normalise the fee budget (floored at 1e-9).
        vol_k: multiplier on the squared window volatility for the vol gate.

    Returns:
        One of "idle", "range_order", "downtrend_harvester", "vol_gated".
    """
    drift = feat["drift_pct"]
    # A flat window has zero path length; floor it (as window_features does)
    # so the drift/path ratio below cannot raise ZeroDivisionError.
    path = max(feat["path_pct"], 1e-9)
    vol = feat["vol_pct"]
    flow = feat["flow_usd_per_hour"]
    events_h = feat["events_per_hour"]
    fee_budget_pct_per_day = fee_rate * flow * 24.0 / max(capital, 1e-9) * 100.0

    if events_h < 1.0 or fee_budget_pct_per_day < 0.05:
        return "idle"

    # Net drift relative to total path travelled: close to +/-1 for a
    # one-directional move, close to 0 for choppy, mean-reverting flow.
    trend = drift / path

    if drift > 3.0 and trend > 0.04 and events_h >= 2.0:
        return "range_order"

    if drift < -3.0 or trend < -0.05:
        return "downtrend_harvester"

    vol_sigma = vol / 100.0
    if fee_rate > vol_k * vol_sigma ** 2:
        return "vol_gated"

    return "idle"


# ── Per-regime LP parameters ──────────────────────────────────────────────────

def regime_params(regime: str, feat: Dict[str, float]) -> Dict[str, float]:
    """Map a regime to range widths (in percent of entry price) and a side label.

    Callers open the position on [p0 * (1 - lower_pct/100), p0 * (1 + upper_pct/100)].
    Any regime other than the three active ones yields a zero-width "idle" entry.
    All three features are read up front, so a missing key raises KeyError
    regardless of the regime.
    """
    def bounded(value: float, floor: float, ceiling: float) -> float:
        # Clamp into [floor, ceiling] and return a plain Python float.
        return float(np.clip(value, floor, ceiling))

    drift = feat["drift_pct"]
    vol = feat["vol_pct"]
    p95 = feat["move_p95_pct"]

    if regime == "range_order":
        # Asymmetric range skewed above price: tight below, wider above, so an
        # up-move converts the position like a take-profit limit order.
        return {
            "lower_pct": bounded(max(0.05, p95 * 0.20), 0.05, 1.5),
            "upper_pct": bounded(max(0.5, p95 * 1.5), 0.5, 6.0),
            "side": "up",
        }

    if regime == "downtrend_harvester":
        # Very deep lower tail (85-99 %) and a tiny upper band so the range
        # keeps earning fees while price drifts down.
        return {
            "lower_pct": bounded(max(85.0, abs(drift) * 2.5, vol * 4.0), 85.0, 99.0),
            "upper_pct": bounded(max(0.01, p95 * 0.10), 0.01, 0.20),
            "side": "down",
        }

    if regime == "vol_gated":
        # Symmetric band whose half-width scales with typical moves / volatility.
        half_width = bounded(max(3.0, p95 * 8.0, vol * 0.9), 3.0, 50.0)
        return {"lower_pct": half_width, "upper_pct": half_width, "side": "both"}

    return {"lower_pct": 0.0, "upper_pct": 0.0, "side": "idle"}


# ── Main ensemble backtester ──────────────────────────────────────────────────

def run_ensemble(
    price: np.ndarray,
    input_usd: np.ndarray,
    active_liq: np.ndarray,
    ts: np.ndarray,
    d0: int,
    d1: int,
    capital: float,
    fee_rate: float,
    rebalance_hours: float = 168.0,
    lookback_hours: float = 168.0,
    vol_k: float = 2.0,
) -> Dict[str, Any]:
    """Backtest the regime-switching LP strategy over one per-swap series.

    Schedule: every ``rebalance_hours`` (measured from the timestamp of the
    first swap of the current segment) any open position is closed, features
    are recomputed over the trailing ``lookback_hours`` and a new position is
    opened for the selected regime (nothing is deployed while idle). A
    ``range_order`` position is additionally closed early at the first swap
    whose price reaches its upper bound, and a new regime is chosen at once.

    Fee model: while price is inside [lo, up] each swap pays
    ``share * fee_rate * input_usd`` with share = L / (active_liq + L)
    (0 when active_liq <= 0). Fees accumulate separately and are added to cash
    on close; the next entry then deploys all cash.

    Args:
        price, input_usd, active_liq, ts: per-swap arrays of equal length.
            ``ts`` is searched with ``np.searchsorted`` and offset by
            hours * 3600, so it is assumed sorted ascending in seconds
            (TODO confirm upstream).
        d0, d1: token decimals forwarded to the CL math helpers.
        capital: starting cash, in the same units as ``position_value``.
        fee_rate: pool fee as a fraction.
        rebalance_hours, lookback_hours, vol_k: strategy parameters.

    Returns:
        Dict with return / annualised return / max drawdown / PnL-to-MDD,
        time in range, pool-share statistics over in-range swaps, regime
        counts, the log of scheduled rebalances and a parameter string.
    """
    n = len(price)
    # Rebalance period in seconds.
    interval = max(1, int(rebalance_hours * 3600))

    # Strategy state: free cash, open position (liquidity L on [lo, up]),
    # fees accrued since the last entry, and the regime that opened it.
    cash = float(capital)
    L = lo = up = 0.0
    fees_cum = 0.0
    deployed = False
    active_regime = "idle"
    active_side = "idle"

    # Per-swap outputs; the main loop writes every slot of equity_arr.
    equity_arr = np.empty(n, dtype=np.float64)
    share_arr = np.zeros(n, dtype=np.float64)
    in_range_arr = np.zeros(n, dtype=np.int8)
    regime_log: List[Dict] = []
    # Counts every close AND every entry attempt (see increments below).
    rebalances = 0
    regime_counts: Dict[str, int] = {
        "idle": 0, "range_order": 0,
        "downtrend_harvester": 0, "vol_gated": 0,
    }

    def enter_position(p0: float, rp: Dict) -> None:
        """Deploy all cash on [p0*(1-lower%), p0*(1+upper%)]; no-op if L would be 0."""
        nonlocal L, lo, up, cash, deployed, fees_cum
        l = p0 * (1.0 - rp["lower_pct"] / 100.0)
        u = p0 * (1.0 + rp["upper_pct"] / 100.0)
        Lv = liquidity_for_capital(cash, p0, l, u, d0, d1)
        if Lv > 0:
            L, lo, up, cash, deployed = Lv, l, u, 0.0, True
            fees_cum = 0.0

    def close_position(p: float) -> None:
        """Convert the open position plus accrued fees back into cash at price p."""
        nonlocal L, lo, up, cash, fees_cum, deployed
        if deployed:
            cash = position_value(L, p, lo, up, d0, d1) + fees_cum
            fees_cum = 0.0
            L = lo = up = 0.0
            deployed = False

    # Initial regime decision. The lookback window at index 0 holds a single
    # sample, so with the current window_features/select_regime this is idle.
    feat0 = window_features(price, input_usd, ts, 0, lookback_hours)
    regime0 = select_regime(feat0, fee_rate, cash, vol_k)
    rp0 = regime_params(regime0, feat0)
    active_regime = regime0
    active_side = rp0["side"]
    regime_counts[regime0] = regime_counts.get(regime0, 0) + 1
    if regime0 != "idle" and cash > 0:
        enter_position(float(price[0]), rp0)
        rebalances += 1

    idx = 0
    while idx < n:
        # Segment [idx, next_idx): swaps before the next scheduled rebalance
        # (always at least one swap so the loop advances).
        next_ts = int(ts[idx]) + interval
        next_idx = int(np.searchsorted(ts, next_ts, side="left"))
        next_idx = min(max(idx + 1, next_idx), n)

        # For range_order regime: check if price crosses upper (take profit)
        if deployed and active_regime == "range_order" and active_side == "up":
            seg_prices = price[idx:next_idx]
            cross = np.where(seg_prices >= up)[0]
            if len(cross):
                # First swap in the segment at or above the upper bound.
                exit_i = idx + int(cross[0])
                # Fill equity up to exit
                if idx < exit_i:
                    sl = slice(idx, exit_i)
                    pseg = price[sl]
                    al_seg = active_liq[sl]
                    in_r = (pseg >= lo) & (pseg <= up)
                    in_range_arr[sl] = in_r.astype(np.int8)
                    sh = np.where(al_seg > 0, L / (al_seg + L), 0.0)
                    share_arr[sl] = sh
                    fees_earned = np.sum(sh[in_r] * fee_rate * input_usd[sl][in_r])
                    fees_cum += float(fees_earned)
                    # Mark-to-market uses the segment's total fees at every swap.
                    equity_arr[sl] = np.array([
                        position_value(L, pseg[k], lo, up, d0, d1) + fees_cum
                        for k in range(len(pseg))
                    ])
                # Close at cross point (the exit swap itself earns no fee).
                p_exit = float(price[exit_i])
                close_position(p_exit)
                equity_arr[exit_i] = cash
                in_range_arr[exit_i] = 0
                rebalances += 1
                # Re-enter with new range immediately after. Note: this
                # decision is counted in regime_counts but not in regime_log.
                feat_x = window_features(price, input_usd, ts, exit_i, lookback_hours)
                regime_x = select_regime(feat_x, fee_rate, cash, vol_k)
                rp_x = regime_params(regime_x, feat_x)
                active_regime = regime_x
                active_side = rp_x["side"]
                regime_counts[regime_x] = regime_counts.get(regime_x, 0) + 1
                if regime_x != "idle" and cash > 0:
                    enter_position(p_exit, rp_x)
                    rebalances += 1
                # The next rebalance is re-timed from the swap after the exit.
                idx = exit_i + 1 if exit_i + 1 < n else n
                continue

        # Fill equity for this segment
        sl = slice(idx, next_idx)
        pseg = price[sl]
        al_seg = active_liq[sl]
        vseg = input_usd[sl]

        if deployed and L > 0:
            # Fees only on swaps whose price lies inside [lo, up] (inclusive).
            in_r = (pseg >= lo) & (pseg <= up)
            in_range_arr[sl] = in_r.astype(np.int8)
            sh = np.where(al_seg > 0, L / (al_seg + L), 0.0)
            share_arr[sl] = sh
            seg_fees = float(np.sum(sh[in_r] * fee_rate * vseg[in_r]))
            fees_cum += seg_fees
            # Mark-to-market uses the segment's total fees at every swap.
            equity_arr[sl] = np.array([
                position_value(L, pseg[k], lo, up, d0, d1) + fees_cum
                for k in range(len(pseg))
            ])
        else:
            equity_arr[sl] = cash
            in_range_arr[sl] = 0
            share_arr[sl] = 0.0

        idx = next_idx
        if idx >= n:
            break

        # Rebalance: close old position and select new regime
        if deployed:
            close_position(float(price[idx]))
            rebalances += 1

        feat = window_features(price, input_usd, ts, idx, lookback_hours)
        regime = select_regime(feat, fee_rate, cash, vol_k)
        rp = regime_params(regime, feat)
        active_regime = regime
        active_side = rp["side"]
        regime_counts[regime] = regime_counts.get(regime, 0) + 1
        regime_log.append({"idx": idx, "ts": int(ts[idx]), "regime": regime,
                           "drift": round(feat["drift_pct"], 2),
                           "vol": round(feat["vol_pct"], 2),
                           "events_h": round(feat["events_per_hour"], 1)})

        if regime != "idle" and cash > 0:
            enter_position(float(price[idx]), rp)
            rebalances += 1

    # Close any open position at end
    if deployed:
        close_position(float(price[-1]))
        equity_arr[-1] = cash

    # Metrics
    # Pool-share statistics only over swaps where the position was in range.
    valid_share = share_arr[in_range_arr.astype(bool)]
    days = (int(ts[-1]) - int(ts[0])) / 86400.0
    ret_pct = (equity_arr[-1] / capital - 1.0) * 100.0
    mdd = max_drawdown(equity_arr)
    # PnL/MDD is reported as 0 for non-positive returns.
    pnl_mdd = abs(ret_pct) / max(abs(mdd), 0.001) if ret_pct > 0 else 0.0
    # Simple linear (non-compounded) annualisation.
    annual_pct = ret_pct / max(days, 0.1) * 365.0

    return {
        "return_pct": round(ret_pct, 3),
        "annual_pct": round(annual_pct, 1),
        "mdd_pct": round(mdd, 3),
        "pnl_mdd": round(pnl_mdd, 3),
        "equity_end": round(float(equity_arr[-1]), 4),
        "days": round(days, 2),
        "rebalances": rebalances,
        "time_in_range_pct": round(float(in_range_arr.mean() * 100), 2),
        "avg_share_pct": round(float(valid_share.mean() * 100) if len(valid_share) else 0.0, 4),
        "p95_share_pct": round(float(np.percentile(valid_share, 95) * 100) if len(valid_share) else 0.0, 4),
        "p99_share_pct": round(float(np.percentile(valid_share, 99) * 100) if len(valid_share) else 0.0, 4),
        "regime_counts": regime_counts,
        "regime_log": regime_log,
        "params": f"rebal={rebalance_hours}h look={lookback_hours}h vol_k={vol_k}",
    }


# ── Grid runner ───────────────────────────────────────────────────────────────

# Parameter grid swept by run_on_npz (4 * 3 * 4 = 48 backtests per NPZ):
#   REBAL_GRID    -> rebalance_hours for run_ensemble (hours)
#   LOOKBACK_GRID -> lookback_hours for window_features (hours)
#   VOL_K_GRID    -> vol_k multiplier of the vol_gated rule in select_regime
REBAL_GRID = [72.0, 168.0, 336.0, 672.0]
LOOKBACK_GRID = [72.0, 168.0, 336.0]
VOL_K_GRID = [1.0, 2.0, 5.0, 10.0]


def _load_npz_meta(z: Any) -> Dict[str, Any]:
    """Best-effort decode of the optional ``meta_json`` entry of a loaded NPZ.

    Accepts a JSON string stored as a 0-d or single-element array (``str`` or
    ``bytes``). Returns ``{}`` if the entry is missing. Prints a warning and
    returns ``{}`` if it cannot be decoded into a JSON object.
    """
    try:
        raw = z["meta_json"]
    except KeyError:
        return {}
    arr = np.asarray(raw)
    # Unwrap a single stored value; anything larger is stringified (and will
    # normally fail to parse, as before).
    value = arr.item() if arr.size == 1 else str(arr)
    if isinstance(value, dict):
        return value
    if isinstance(value, bytes):
        value = value.decode("utf-8", errors="replace")
    try:
        meta = json.loads(str(value))
    except (TypeError, ValueError) as exc:
        print(f"  [warn] could not parse meta_json ({exc}); using defaults")
        return {}
    if not isinstance(meta, dict):
        print("  [warn] meta_json is not a JSON object; using defaults")
        return {}
    return meta


def run_on_npz(npz_path: str, capital: float, d0: int, d1: int,
               out_dir: Path) -> pd.DataFrame:
    """Run the full parameter grid on one NPZ swap file and save its summary.

    The NPZ must hold ``price``, ``input_usd``, ``active_liquidity`` and
    ``ts`` arrays. An optional ``meta_json`` entry may provide ``fee_rate``
    (default 0.003), ``pool_name`` (default: file stem) and ``rows_swap``.

    Writes ``<out_dir>/<npz stem>/summary.csv``, prints the best combination,
    and returns one row per (rebal_h, look_h, vol_k) combination.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code -- only load trusted NPZ files.
    # The context manager closes the archive's file handle once the arrays
    # have been copied out (astype copies by default).
    with np.load(npz_path, allow_pickle=True) as z:
        price = z["price"].astype(np.float64)
        input_usd = z["input_usd"].astype(np.float64)
        active_liq = z["active_liquidity"].astype(np.float64)
        ts = z["ts"].astype(np.int64)
        meta = _load_npz_meta(z)

    fee_rate = float(meta.get("fee_rate", 0.003))
    pool_name = meta.get("pool_name", Path(npz_path).stem)
    rows_swap = meta.get("rows_swap", len(price))
    days = (int(ts[-1]) - int(ts[0])) / 86400.0

    print(f"\n{'='*60}")
    print(f"NPZ: {Path(npz_path).name}")
    print(f"Pool: {pool_name} | {rows_swap} swaps | {days:.1f}d | fee_rate={fee_rate:.4f}")
    print(f"Price: {price.min():.4f}–{price.max():.4f} | Vol: ${input_usd.sum():,.0f}")

    rows: List[Dict] = []
    for rebal in REBAL_GRID:
        for look in LOOKBACK_GRID:
            for vk in VOL_K_GRID:
                r = run_ensemble(price, input_usd, active_liq, ts,
                                 d0, d1, capital, fee_rate,
                                 rebalance_hours=rebal,
                                 lookback_hours=look,
                                 vol_k=vk)
                rows.append({
                    "npz": Path(npz_path).name,
                    "pool": pool_name,
                    "days": r["days"],
                    "capital": capital,
                    "rebal_h": rebal,
                    "look_h": look,
                    "vol_k": vk,
                    "return_pct": r["return_pct"],
                    "annual_pct": r["annual_pct"],
                    "mdd_pct": r["mdd_pct"],
                    "pnl_mdd": r["pnl_mdd"],
                    "p99_share_pct": r["p99_share_pct"],
                    "time_in_range_pct": r["time_in_range_pct"],
                    "rebalances": r["rebalances"],
                    "regime_idle": r["regime_counts"].get("idle", 0),
                    "regime_range_order": r["regime_counts"].get("range_order", 0),
                    "regime_downtrend": r["regime_counts"].get("downtrend_harvester", 0),
                    "regime_vol_gated": r["regime_counts"].get("vol_gated", 0),
                    # Strict gate: profitable, MDD above -20 %, PnL/MDD >= 2
                    # and p99 pool share under 10 %.
                    "pass_strict": int(
                        r["return_pct"] > 0
                        and r["mdd_pct"] > -20.0
                        and r["pnl_mdd"] >= 2.0
                        and r["p99_share_pct"] < 10.0
                    ),
                })

    df = pd.DataFrame(rows)

    # Save per-NPZ results
    stem = Path(npz_path).stem
    (out_dir / stem).mkdir(parents=True, exist_ok=True)
    df.to_csv(out_dir / stem / "summary.csv", index=False)

    valid = df[df["pass_strict"] == 1]
    print(f"\nGrid: {len(df)} combos | Pass strict: {len(valid)}")
    if not valid.empty:
        best = valid.sort_values("annual_pct", ascending=False).iloc[0]
        print(f"  BEST: rebal={best.rebal_h}h look={best.look_h}h vol_k={best.vol_k} "
              f"→ {best.annual_pct:.0f}% ann MDD={best.mdd_pct:.1f}% "
              f"PnL/MDD={best.pnl_mdd:.2f} p99={best.p99_share_pct:.2f}%")
    else:
        best_any = df.sort_values("annual_pct", ascending=False).iloc[0]
        print(f"  No strict pass. Best: {best_any.annual_pct:.0f}% ann "
              f"MDD={best_any.mdd_pct:.1f}% PnL/MDD={best_any.pnl_mdd:.2f}")
    return df


def _best_per_npz(combined: pd.DataFrame) -> pd.DataFrame:
    """One summary row per NPZ: best strict-passing combo by annual_pct, else best overall."""
    records: List[Dict] = []
    for npz_name, group in combined.groupby("npz"):
        passing = group[group["pass_strict"] == 1]
        if passing.empty:
            pick = group.sort_values("annual_pct", ascending=False).iloc[0]
            status = "FAIL"
        else:
            pick = passing.sort_values("annual_pct", ascending=False).iloc[0]
            status = "PASS"
        records.append({
            "npz": npz_name,
            "pool": group["pool"].iloc[0],
            "days": group["days"].iloc[0],
            "status": status,
            "n_pass_strict": len(passing),
            "best_annual_pct": pick.annual_pct,
            "best_return_pct": pick.return_pct,
            "best_mdd_pct": pick.mdd_pct,
            "best_pnl_mdd": pick.pnl_mdd,
            "best_p99_share": pick.p99_share_pct,
            "best_rebal_h": pick.rebal_h,
            "best_look_h": pick.look_h,
            "best_vol_k": pick.vol_k,
        })
    return pd.DataFrame(records)


def main() -> None:
    """CLI entry point: grid-test every NPZ and write per-pool and cross-pool CSVs."""
    parser = argparse.ArgumentParser(description="Ensemble Selector DEX LP Backtester")
    parser.add_argument("--npzs", nargs="+", required=True)
    parser.add_argument("--out-dir", default="DEX_REPORTS/ensemble_test")
    parser.add_argument("--capital", type=float, default=600.0)
    parser.add_argument("--dec0", type=int, default=6)
    parser.add_argument("--dec1", type=int, default=18)
    args = parser.parse_args()

    print(f"[{SCRIPT_VERSION}]")
    print(f"Capital: ${args.capital} | d0={args.dec0} d1={args.dec1}")

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    per_npz = [run_on_npz(npz, args.capital, args.dec0, args.dec1, out_dir)
               for npz in args.npzs]
    if not per_npz:
        return

    combined = pd.concat(per_npz, ignore_index=True)
    combined.to_csv(out_dir / "cross_pool_all.csv", index=False)

    summary = _best_per_npz(combined)
    summary.to_csv(out_dir / "ensemble_cross_pool_results.csv", index=False)

    n_pass = int((summary["status"] == "PASS").sum())
    n_total = len(summary)
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"ENSEMBLE SELECTOR RESULTS: {n_pass}/{n_total} pools PASS strict")
    print(rule)
    for _, row in summary.iterrows():
        icon = "✓" if row.status == "PASS" else "✗"
        print(f"  {icon} {row.npz[:40]:40s} → {row.best_annual_pct:6.0f}% ann "
              f"MDD={row.best_mdd_pct:5.1f}% PnL/MDD={row.best_pnl_mdd:4.2f} "
              f"[{row.status}]")

    # At least three passing pools are required before recommending paper trading.
    if n_pass >= 3:
        verdict = f"\n>>> {n_pass}/{n_total} pass — RECOMMEND paper trading"
    else:
        verdict = f"\n>>> Only {n_pass}/{n_total} pass — needs more work"
    print(verdict)

    print(f"\nResults saved to: {out_dir}/")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
