#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import subprocess
import sys
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
import pandas as pd


# Repository root: two directories above this file (presumably <root>/<pkg>/<subdir>/this.py
# — TODO confirm against the actual checkout layout).
ROOT = Path(__file__).resolve().parents[2]
# Backtester script executed as a subprocess once per (dataset, rebalance, lookback) point.
BACKTESTER = ROOT / "dex_platform" / "backtest" / "cl_adaptive_algorithmic_lp_v1.py"


def load_meta(npz: Path) -> Dict[str, Any]:
    """Return the dataset metadata stored under the ``meta_json`` key of *npz*.

    The archive is expected to carry its metadata as a JSON string saved in a
    ``meta_json`` array entry. Returns an empty dict when the key is absent.

    Fix: ``np.load`` on an ``.npz`` returns an ``NpzFile`` that keeps the
    underlying zip file open; the original never closed it, leaking a file
    handle per dataset. Use it as a context manager so the handle is released.
    """
    # allow_pickle=False: refuse object arrays — metadata is plain JSON text.
    with np.load(npz, allow_pickle=False) as z:
        if "meta_json" not in z.files:
            return {}
        # str() collapses the 0-d string array to its Python string payload.
        return json.loads(str(z["meta_json"]))


def dataset_name(npz: Path) -> str:
    """Directory-safe dataset label: the file's stem with dots turned into underscores."""
    return "_".join(npz.stem.split("."))


def run_one(npz: Path, out_dir: Path, fee_rate: float, policies: str, capitals: str, rebalance: float, lookback: float, resume: bool) -> None:
    """Invoke the backtester subprocess for one dataset at one grid point.

    Results land in ``out_dir/<dataset>/reb_<R>h__look_<L>h``. With *resume*
    set, a grid point whose summary.csv already exists and is non-empty is
    skipped, so an interrupted sweep can be re-run cheaply.
    """
    grid_dir = out_dir / dataset_name(npz) / f"reb_{rebalance:g}h__look_{lookback:g}h"
    summary_path = grid_dir / "summary.csv"
    if resume and summary_path.exists() and summary_path.stat().st_size > 0:
        return

    # Assemble the CLI flag-by-flag; values mirror the backtester's interface.
    argv = [sys.executable, str(BACKTESTER)]
    argv += ["--npz", str(npz)]
    argv += ["--out-dir", str(grid_dir)]
    argv += ["--fee-rates", f"metadata:{fee_rate:g}"]
    argv += ["--capital-grid", capitals]
    argv += ["--policies", policies]
    argv += ["--rebalance-hours", str(rebalance)]
    argv += ["--lookback-hours", str(lookback)]
    argv += ["--top-n", "200"]
    # check=True: propagate a non-zero exit as CalledProcessError to abort the sweep.
    subprocess.run(argv, cwd=ROOT, check=True)


def aggregate(out_dir: Path) -> None:
    """Combine every per-run ``summary.csv`` under *out_dir* into analysis CSVs.

    Writes to ``out_dir/_analysis``:
      - ``all_results.csv``            — all summaries, tagged with dataset/grid.
      - ``all_results_clean.csv``      — same, minus hard-excluded datasets.
      - ``best_by_dataset_score_clean.csv`` — top row per dataset by "score".
      - ``dataset_pass_fail_clean.csv``     — per-dataset strict pass/fail table.
      - ``policy_capital_robustness_clean.csv`` — per (strategy, capital) stats.

    Silently does nothing when no summary files are found. Column names
    (score, return_pct, mdd_pct, ...) are assumed to match the backtester's
    summary schema — TODO confirm against cl_adaptive_algorithmic_lp_v1.py.
    """
    frames: List[pd.DataFrame] = []
    for p in out_dir.glob("*/*/summary.csv"):
        df = pd.read_csv(p)
        # Path layout is out_dir/<dataset>/<grid>/summary.csv, so the
        # grandparent dir names the dataset and the parent names the grid point.
        df.insert(0, "dataset", p.parents[1].name)
        df.insert(1, "grid", p.parent.name)
        frames.append(df)
    if not frames:
        return
    all_df = pd.concat(frames, ignore_index=True)
    analysis = out_dir / "_analysis"
    analysis.mkdir(parents=True, exist_ok=True)
    all_df.to_csv(analysis / "all_results.csv", index=False)

    # Drop known-bad datasets by substring match (regex=False: literal match).
    clean = all_df.copy()
    for pat in ["aero_weth", "bio_usdc_03_2026-04-01"]:
        clean = clean[~clean["dataset"].str.contains(pat, regex=False)]
    clean.to_csv(analysis / "all_results_clean.csv", index=False)

    # NOTE(review): if read_csv parsed this flag column as strings, astype(bool)
    # treats any non-empty string — including "False" — as True; confirm the
    # column round-trips through CSV as real booleans.
    valid = clean[clean["valid_capacity_avg_p95_p99"].astype(bool)]
    # "Strict" pass: capacity-valid, positive return, and max drawdown <= 25%.
    strict = valid[(valid["return_pct"] > 0) & (valid["mdd_pct"].abs() <= 25)]
    # Best row per dataset is taken from `clean` (not `valid`) — presumably
    # intentional so every dataset gets a best row; verify with the author.
    best = clean.sort_values("score", ascending=False).groupby("dataset", as_index=False).head(1)
    best.to_csv(analysis / "best_by_dataset_score_clean.csv", index=False)

    # One pass/fail row per dataset, carrying that dataset's best-scoring run.
    rows = []
    for ds, g in clean.groupby("dataset"):
        gv = strict[strict["dataset"].eq(ds)]
        b = g.sort_values("score", ascending=False).iloc[0]
        rows.append(
            {
                "dataset": ds,
                # True when at least one run for this dataset passed the strict filter.
                "strict_pass": bool(len(gv)),
                "best_strategy": b["strategy"],
                "best_grid": b["grid"],
                "best_capital": b["capital_usd"],
                "best_return_pct": b["return_pct"],
                "best_mdd_pct": b["mdd_pct"],
                "best_score": b["score"],
                "best_p95_share": b["p95_liquidity_share_pct_when_in_range"],
                "best_p99_share": b["p99_liquidity_share_pct_when_in_range"],
                # Optional column; empty string when the summary lacks it.
                "reason": b.get("policy_reason_top", ""),
            }
        )
    # Sort order: failures (strict_pass=False) first, then best_score descending.
    pf = pd.DataFrame(rows).sort_values(["strict_pass", "best_score"], ascending=[True, False])
    pf.to_csv(analysis / "dataset_pass_fail_clean.csv", index=False)

    # Robustness of each (strategy, capital) pair across capacity-valid runs.
    robust_rows = []
    for (strategy, capital), g in valid.groupby(["strategy", "capital_usd"]):
        robust_rows.append(
            {
                "strategy": strategy,
                "capital": capital,
                "datasets": g["dataset"].nunique(),
                # Count of runs meeting the same strict criteria as above.
                "strict_passes": int(((g["return_pct"] > 0) & (g["mdd_pct"].abs() <= 25)).sum()),
                "positive": int((g["return_pct"] > 0).sum()),
                "median_return": g["return_pct"].median(),
                "p10_return": g["return_pct"].quantile(0.10),
                "min_return": g["return_pct"].min(),
                "median_mdd": g["mdd_pct"].median(),
                # min() — drawdowns are presumably negative, so this is the worst.
                "worst_mdd": g["mdd_pct"].min(),
                "max_p99_share": g["p99_liquidity_share_pct_when_in_range"].max(),
                "median_score": g["score"].median(),
            }
        )
    # Single ascending=False applies to all three sort keys (all descending).
    robust = pd.DataFrame(robust_rows).sort_values(["strict_passes", "positive", "median_score"], ascending=False)
    robust.to_csv(analysis / "policy_capital_robustness_clean.csv", index=False)

    # One-line machine-readable recap on stdout for the calling pipeline.
    print(json.dumps({"analysis": str(analysis), "clean_datasets": int(clean["dataset"].nunique()), "strict_pass_datasets": int(pf["strict_pass"].sum())}, ensure_ascii=False))


def main() -> None:
    """CLI entry point: sweep every dataset over the rebalance x lookback grid,
    record the plan, run each grid point, then aggregate all results."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", default=str(ROOT / "DEX_DATA" / "fast_npz"))
    parser.add_argument("--out-dir", default=str(ROOT / "DEX_REPORTS" / "algorithmic_lp_search_v2_pdf" / "cross_pool_grid"))
    parser.add_argument("--npz-glob", default="*.npz")
    parser.add_argument("--policies", default="pdf_vol_gated_lp,pdf_range_order,pdf_toxicity_gated_lp,downtrend_fee_harvester,passive_wide_evidence")
    parser.add_argument("--capital-grid", default="0.25,0.5,1,2,5,10,25")
    parser.add_argument("--rebalance-grid", default="6,12,24,48,72,168,336,672")
    parser.add_argument("--lookback-grid", default="6,12,24,48,72,168")
    parser.add_argument("--resume", action="store_true")
    opts = parser.parse_args()

    def parse_grid(spec: str) -> List[float]:
        # Comma-separated floats; blank entries are ignored.
        return [float(tok) for tok in spec.split(",") if tok.strip()]

    out_dir = Path(opts.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    datasets = sorted(Path(opts.data_dir).glob(opts.npz_glob))
    rebalance_grid = parse_grid(opts.rebalance_grid)
    lookback_grid = parse_grid(opts.lookback_grid)

    # Persist the sweep plan up front so a partial run is self-describing.
    plan = {"datasets": len(datasets), "rebalances": rebalance_grid, "lookbacks": lookback_grid, "out_dir": str(out_dir)}
    (out_dir / "plan.json").write_text(json.dumps(plan, indent=2), encoding="utf-8")

    for npz in datasets:
        meta = load_meta(npz)
        # Fall back to a 0.3% fee when the archive metadata omits fee_rate.
        fee_rate = float(meta.get("fee_rate", 0.003))
        for reb in rebalance_grid:
            for look in lookback_grid:
                run_one(npz, out_dir, fee_rate, opts.policies, opts.capital_grid, reb, look, opts.resume)
    aggregate(out_dir)


# Standard script guard: run the sweep only when executed directly.
if __name__ == "__main__":
    main()
