#!/usr/bin/env python3

"""
dex_platform/backtest/cl_fee_replay_fast_npz_v3.py

Capacity-aware fast NPZ fee replay tuner for Aerodrome/Uniswap-v3-like CL LP ranges.

Keeps v2 logic:
- load Swap-only NPZ
- time filtering
- static range / periodic rebalance / out-of-range rebalance
- corrected fee share: our_liquidity / (active_liquidity + our_liquidity)
- separate fee accounting: earned / reinvested / uncollected / rebalance costs

v3 additions:
- capital grid / fixed / auto_cap modes
- avg/p95/p99/max liquidity-share metrics
- capacity-aware score
- deployable capital and return-on-total-capital metrics
- summary/best output files

This is a deterministic replay/tuning tool, not a live trading engine.
"""

import argparse


def _add_bool_arg_compat(ap, name, default=False, help=None):
    """Register a boolean --flag/--no-flag pair on *ap*, even on Python < 3.9."""
    dest = name.lstrip("-").replace("-", "_")
    action = getattr(argparse, "BooleanOptionalAction", None)
    if action is not None:
        # Modern path: one action handles both --flag and --no-flag.
        ap.add_argument(name, action=action, default=default, help=help)
        return
    # Fallback: an explicit store_true / store_false pair on the same dest.
    pair = ap.add_mutually_exclusive_group()
    pair.add_argument(name, dest=dest, action="store_true", help=help)
    pair.add_argument("--no-" + name.lstrip("-"), dest=dest, action="store_false")
    ap.set_defaults(**{dest: default})


import ctypes
import csv
import hashlib
import json
import multiprocessing as mp
import os
import subprocess
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd

SCRIPT_VERSION = "cl_fee_replay_fast_npz_v3_capacity_2026_05_02"
_WORKER_CTX: Dict[str, Any] = {}
_FAST_CORE: Any = None


@dataclass(frozen=True)
class Strategy:
    """One LP range configuration: band widths around the entry price plus rebalance policy."""
    name: str
    # Band below/above the entry price, in percent of that price.
    lower_pct: float
    upper_pct: float
    # Minimum hours between rebalances; 0 disables rebalancing.
    rebalance_hours: float = 0.0
    # Flat gas cost (USD) charged at each rebalance.
    gas_usd: float = 0.0
    # Proportional cost (basis points) on the redeployed value at each rebalance.
    swap_cost_bps: float = 0.0
    mode: str = "none"  # none | periodic | oor


def parse_iso_ts(s: str) -> int:
    """Parse an ISO-8601 timestamp ('Z' suffix or ±HH:MM offset) to a UTC epoch second."""
    text = str(s).strip()
    tail = text[10:]  # skip the date part so its '-' separators are not mistaken for an offset
    if text.endswith("Z"):
        tz = timezone.utc
        text = text[:-1]
    elif "+" in tail or "-" in tail:
        # Manual offset parsing keeps this compatible with Python 3.6.
        from datetime import timedelta
        positive = "+" in tail
        cut = text.rfind("+") if positive else text.rfind("-")
        hh, mm = text[cut + 1 :].split(":")
        delta = timedelta(hours=int(hh), minutes=int(mm))
        tz = timezone(delta if positive else -delta)
        text = text[:cut]
    else:
        # No explicit zone: treat as UTC.
        tz = timezone.utc
    naive = datetime.strptime(text, "%Y-%m-%dT%H:%M:%S")
    return int(naive.replace(tzinfo=tz).astimezone(timezone.utc).timestamp())


def parse_float_list(s: str) -> List[float]:
    """Split a comma-separated string into floats, skipping blank entries."""
    values: List[float] = []
    for piece in str(s).split(","):
        piece = piece.strip()
        if piece:
            values.append(float(piece))
    return values


def parse_fee_specs(spec: str) -> List[Tuple[str, float]]:
    """Parse 'name:rate,name:rate' into [(name, rate), ...].

    Fix: the name is now stripped of surrounding whitespace (previously
    "a : 0.3" produced the name "a "), matching how parse_rebalance_grid
    strips its mode field.

    Raises ValueError on a malformed item or an entirely empty spec.
    """
    out: List[Tuple[str, float]] = []
    for item in str(spec).split(","):
        item = item.strip()
        if not item:
            continue
        if ":" not in item:
            raise ValueError(f"Bad fee spec: {item}; expected name:rate")
        name, rate = item.split(":", 1)
        out.append((name.strip(), float(rate)))
    if not out:
        raise ValueError("No fee specs")
    return out


def parse_rebalance_grid(spec: str) -> List[Tuple[str, float]]:
    """Parse 'mode:hours,...' pairs, e.g. none:0,periodic:168,periodic:336,oor:24."""
    allowed = {"none", "periodic", "oor"}
    pairs: List[Tuple[str, float]] = []
    for chunk in str(spec).split(","):
        chunk = chunk.strip()
        if not chunk:
            continue
        if ":" not in chunk:
            raise ValueError(f"Bad rebalance item: {chunk}; expected mode:hours")
        mode, hours = chunk.split(":", 1)
        mode = mode.strip()
        if mode not in allowed:
            raise ValueError(f"Bad rebalance mode: {mode}")
        pairs.append((mode, float(hours)))
    # An empty spec means "no rebalancing".
    return pairs or [("none", 0.0)]


def parse_strategy(spec: str) -> Strategy:
    """Parse 'name:lower:upper[:rebalance_hours[:gas_usd[:swap_cost_bps[:mode]]]]'.

    Blank/missing optional fields default to 0.0; mode defaults to 'periodic'
    when rebalance_hours > 0, else 'none'.

    Fix: an explicit mode is now validated against the known set (consistent
    with parse_rebalance_grid). Previously a typo like 'perodic' was silently
    accepted and treated as periodic downstream.

    Raises ValueError on fewer than three fields or an unknown mode.
    """
    parts = spec.split(":")
    if len(parts) < 3:
        raise ValueError(f"Bad strategy spec: {spec}")
    reb = float(parts[3]) if len(parts) > 3 and parts[3] else 0.0
    gas = float(parts[4]) if len(parts) > 4 and parts[4] else 0.0
    swap = float(parts[5]) if len(parts) > 5 and parts[5] else 0.0
    mode = str(parts[6]) if len(parts) > 6 and parts[6] else ("periodic" if reb > 0 else "none")
    if mode not in {"none", "periodic", "oor"}:
        raise ValueError(f"Bad rebalance mode: {mode}")
    return Strategy(parts[0], float(parts[1]), float(parts[2]), reb, gas, swap, mode)


def load_npz(path: Any) -> Dict[str, Any]:
    """Load a swap NPZ into a dict of arrays, plus a decoded 'meta' mapping."""
    npz_path = Path(path)
    if not npz_path.exists():
        raise SystemExit(f"NPZ not found: {npz_path}")
    archive = np.load(npz_path, allow_pickle=False)
    data: Dict[str, Any] = {name: archive[name] for name in archive.files}
    # 'meta_json' is an optional embedded JSON blob describing the dataset.
    data["meta"] = json.loads(str(data["meta_json"])) if "meta_json" in data else {}
    return data


def load_fast_core() -> Any:
    """Compile (once) and load the optional C replay core; return None if unavailable.

    The shared object is cached in /tmp, keyed by a hash of the C source and
    the compiler flags, so a source change forces a rebuild.

    Fix: compilation now targets a per-process temp path followed by an atomic
    os.replace. Previously gcc wrote directly to the shared cache path, so a
    concurrent worker process could CDLL a half-written .so.
    """
    global _FAST_CORE
    if _FAST_CORE is not None:
        return _FAST_CORE
    src = Path(__file__).with_name("cl_replay_fast_core.c")
    if not src.exists():
        return None
    cflags = b"-Ofast -march=native"
    digest = hashlib.sha256(src.read_bytes() + cflags).hexdigest()[:16]
    so = Path("/tmp") / f"dex_cl_replay_fast_core_{digest}.so"
    if not so.exists():
        tmp = Path(str(so) + f".{os.getpid()}.tmp")
        cmd = ["gcc", "-Ofast", "-march=native", "-shared", "-fPIC", str(src), "-lm", "-o", str(tmp)]
        try:
            subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            os.replace(str(tmp), str(so))  # atomic publish: readers never see a partial file
        except Exception:
            # Best effort: missing gcc / failed build just disables the fast path.
            return None
    try:
        lib = ctypes.CDLL(str(so))
    except OSError:
        return None

    # Declare the C ABI: 5 input arrays, scalar config, then 24 output arrays.
    dbl_p = np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags="C_CONTIGUOUS")
    i64_p = np.ctypeslib.ndpointer(dtype=np.int64, ndim=1, flags="C_CONTIGUOUS")
    lib.replay_many.argtypes = [
        dbl_p, dbl_p, dbl_p, dbl_p, i64_p,
        ctypes.c_int, ctypes.c_int, ctypes.c_int,
        dbl_p, ctypes.c_int,
        ctypes.c_double, ctypes.c_double, ctypes.c_int, ctypes.c_double,
        ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double,
    ] + [dbl_p] * 24
    lib.replay_many.restype = None
    _FAST_CORE = lib
    return _FAST_CORE


def filter_time(data: Dict[str, Any], time_from: str, time_to: str) -> Dict[str, Any]:
    """Return a shallow copy of *data* with event arrays restricted to [time_from, time_to)."""
    ts = data["ts"].astype(np.int64)
    keep = np.ones(len(ts), dtype=bool)
    if time_from:
        keep &= ts >= parse_iso_ts(time_from)
    if time_to:
        keep &= ts < parse_iso_ts(time_to)
    if not keep.any():
        raise SystemExit(f"Empty time slice: {time_from} -> {time_to}")
    sliced = dict(data)
    # Only per-event columns are masked; scalar/meta entries pass through untouched.
    for key in ("ts", "block", "log_index", "tick", "price", "amount0_h", "amount1_h", "input_usd", "active_liquidity"):
        if key in sliced:
            sliced[key] = sliced[key][keep]
    return sliced


def sqrt_raw_token1_per_token0_from_price(price_token0_per_token1: Any, dec0: int, dec1: int) -> np.ndarray:
    """Sqrt of the raw token1/token0 price implied by a human token0/token1 price."""
    human = np.asarray(price_token0_per_token1, dtype=np.float64)
    scale = 10 ** (dec1 - dec0)
    # Clamp the denominator so a zero price cannot divide by zero.
    return np.sqrt(scale / np.maximum(human, 1e-300))


def amounts_raw_for_liquidity_vec(liquidity_raw: float, price: np.ndarray, lower_price: float, upper_price: float, dec0: int, dec1: int) -> Tuple[np.ndarray, np.ndarray]:
    """Raw token amounts backing *liquidity_raw* at each price, for a [lower, upper] range."""
    prices = np.asarray(price, dtype=np.float64)
    lo = max(float(lower_price), 1e-300)
    up = max(float(upper_price), lo * 1.000001)  # keep the range non-degenerate

    # Because the human price is token0/token1, the upper human price maps to
    # the smaller raw sqrt price (sqrt_a <= sqrt_b).
    sqrt_p = sqrt_raw_token1_per_token0_from_price(prices, dec0, dec1)
    sqrt_a = float(sqrt_raw_token1_per_token0_from_price(up, dec0, dec1))
    sqrt_b = float(sqrt_raw_token1_per_token0_from_price(lo, dec0, dec1))

    L = float(liquidity_raw)
    amount0 = np.zeros_like(prices, dtype=np.float64)
    amount1 = np.zeros_like(prices, dtype=np.float64)

    below = prices <= lo
    above = prices >= up
    inside = ~(below | above)

    # Below range: position is all token1; above range: all token0.
    amount1[below] = L * (sqrt_b - sqrt_a)
    amount0[above] = L * (sqrt_b - sqrt_a) / (sqrt_a * sqrt_b)

    # In range: split between both tokens at the current sqrt price.
    amount0[inside] = L * (sqrt_b - sqrt_p[inside]) / (sqrt_p[inside] * sqrt_b)
    amount1[inside] = L * (sqrt_p[inside] - sqrt_a)
    return amount0, amount1


def value_usd_from_raw(amount0_raw: np.ndarray, amount1_raw: np.ndarray, price: np.ndarray, dec0: int, dec1: int) -> np.ndarray:
    """USD value of raw token amounts; token0 is the USD-quoted side, token1 is priced via *price*."""
    human0 = amount0_raw / (10 ** dec0)
    human1 = amount1_raw / (10 ** dec1)
    return human0 + human1 * price


def unit_value_curve(price: np.ndarray, lower_price: float, upper_price: float, dec0: int, dec1: int) -> np.ndarray:
    """USD value over *price* of exactly one unit of raw liquidity in the range."""
    amt0, amt1 = amounts_raw_for_liquidity_vec(1.0, price, lower_price, upper_price, dec0, dec1)
    return value_usd_from_raw(amt0, amt1, price, dec0, dec1)


def unit_value_curve_from_sqrt(sqrt_p: np.ndarray, price: np.ndarray, lower_price: float, upper_price: float, dec0: int, dec1: int) -> np.ndarray:
    """Like unit_value_curve, but reuses a precomputed sqrt-price vector (hot path)."""
    lo = max(float(lower_price), 1e-300)
    up = max(float(upper_price), lo * 1.000001)
    sqrt_a = float(sqrt_raw_token1_per_token0_from_price(up, dec0, dec1))
    sqrt_b = float(sqrt_raw_token1_per_token0_from_price(lo, dec0, dec1))
    inv_dec0 = 10.0 ** (-dec0)
    inv_dec1 = 10.0 ** (-dec1)

    value = np.empty_like(price, dtype=np.float64)
    below = price <= lo
    above = price >= up
    inside = ~(below | above)

    # Same three price regimes as amounts_raw_for_liquidity_vec, fused into one USD value.
    value[below] = (sqrt_b - sqrt_a) * inv_dec1 * price[below]
    value[above] = ((sqrt_b - sqrt_a) / (sqrt_a * sqrt_b)) * inv_dec0
    sp = sqrt_p[inside]
    value[inside] = ((sqrt_b - sp) / (sp * sqrt_b)) * inv_dec0 + (sp - sqrt_a) * inv_dec1 * price[inside]
    return value


def unit_value_one_from_sqrt(sqrt_p: float, price: float, lower_price: float, upper_price: float, dec0: int, dec1: int) -> float:
    """Scalar counterpart of unit_value_curve_from_sqrt for a single price point."""
    lo = max(float(lower_price), 1e-300)
    up = max(float(upper_price), lo * 1.000001)
    sqrt_a = float(sqrt_raw_token1_per_token0_from_price(up, dec0, dec1))
    sqrt_b = float(sqrt_raw_token1_per_token0_from_price(lo, dec0, dec1))
    if price <= lo:
        # All token1.
        return float((sqrt_b - sqrt_a) * (10.0 ** (-dec1)) * price)
    if price >= up:
        # All token0.
        return float(((sqrt_b - sqrt_a) / (sqrt_a * sqrt_b)) * (10.0 ** (-dec0)))
    # In range: both tokens contribute.
    return float(((sqrt_b - sqrt_p) / (sqrt_p * sqrt_b)) * (10.0 ** (-dec0)) + (sqrt_p - sqrt_a) * (10.0 ** (-dec1)) * price)


def liquidity_for_capital(capital_usd: float, open_price: float, lower_price: float, upper_price: float, dec0: int, dec1: int) -> float:
    """Raw liquidity purchasable with *capital_usd* when entering at *open_price*."""
    probe = np.array([open_price], dtype=np.float64)
    per_unit = float(unit_value_curve(probe, lower_price, upper_price, dec0, dec1)[0])
    if per_unit <= 1e-300:
        # Degenerate range: no finite amount of liquidity has value here.
        return 0.0
    return float(capital_usd) / per_unit


def liquidity_share(our_liq: float, active_liq: Any) -> np.ndarray:
    """Our fraction of total in-range liquidity once our position is added to the pool."""
    ours = float(our_liq)
    pool = np.asarray(active_liq, dtype=np.float64)
    return ours / (pool + ours)


def max_drawdown_pct(equity: np.ndarray) -> float:
    """Worst peak-to-trough drawdown of *equity*, in percent (non-positive)."""
    if len(equity) == 0:
        return 0.0
    running_peak = np.maximum.accumulate(equity)
    # Mask zero peaks with NaN so the division cannot blow up.
    safe_peak = np.where(running_peak == 0, np.nan, running_peak)
    drawdown = equity / safe_peak - 1.0
    return float(np.nanmin(drawdown) * 100.0)


def share_stats_pct(share_in: np.ndarray) -> Tuple[float, float, float, float]:
    """(mean, p95, p99, max) of a fractional share series, expressed in percent."""
    if len(share_in) == 0:
        return 0.0, 0.0, 0.0, 0.0
    pct = np.asarray(share_in, dtype=np.float64) * 100.0
    return (
        float(np.mean(pct)),
        float(np.percentile(pct, 95)),
        float(np.percentile(pct, 99)),
        float(np.max(pct)),
    )


def base_summary(strategy: Strategy, initial_capital: float) -> Dict[str, Any]:
    """Seed a summary row with the strategy's identity and the capital level."""
    capital = float(initial_capital)
    return {
        "strategy": strategy.name,
        "lower_pct": float(strategy.lower_pct),
        "upper_pct": float(strategy.upper_pct),
        "rebalance_hours": float(strategy.rebalance_hours),
        "rebalance_mode": strategy.mode,
        "capital_usd": capital,
        "initial_capital_usd": capital,
    }


def finish_summary(row: Dict[str, Any], equity: np.ndarray, pos_value: np.ndarray, fee_total: float, fees_reinvested: float, fees_uncollected_end: float, rebalance_costs: float, in_range: np.ndarray, share: np.ndarray, hodl50: np.ndarray, price: np.ndarray, initial_capital: float, rebalances: int, total_capital_usd: float) -> Dict[str, Any]:
    """Fill *row* with end-of-replay performance, fee, and capacity metrics and return it."""
    in_mask = in_range.astype(bool)
    # Liquidity-share stats are measured only while the position is in range.
    avg_s, p95_s, p99_s, max_s = share_stats_pct(share[in_mask])
    profit = float(equity[-1] - initial_capital)
    row.update({
        "equity_end_usd": float(equity[-1]),
        "return_pct": float((equity[-1] / initial_capital - 1.0) * 100.0),
        "mdd_pct": max_drawdown_pct(equity),
        "fees_earned_total": float(fee_total),
        "fees_reinvested": float(fees_reinvested),
        "fees_uncollected_end": float(fees_uncollected_end),
        "rebalance_costs": float(rebalance_costs),
        "position_value_end_usd": float(pos_value[-1]),
        "time_in_range_pct": float(in_mask.mean() * 100.0),
        "avg_liquidity_share_pct_when_in_range": avg_s,
        "p95_liquidity_share_pct_when_in_range": p95_s,
        "p99_liquidity_share_pct_when_in_range": p99_s,
        "max_liquidity_share_pct_when_in_range": max_s,
        "rebalances": int(rebalances),
        # Benchmark: buy-and-hold a 50/50 split of the two tokens.
        "hodl50_return_pct": float((hodl50[-1] / initial_capital - 1.0) * 100.0),
        "vs_hodl50_usd": float(equity[-1] - hodl50[-1]),
        "price_start": float(price[0]),
        "price_end": float(price[-1]),
        "price_return_pct": float((price[-1] / price[0] - 1.0) * 100.0),
        "profit_usd": profit,
        # Deployed = what this row actually put in range; unused = idle remainder.
        "deployable_capital_usd": float(initial_capital),
        "unused_capital_usd": float(max(0.0, total_capital_usd - initial_capital)),
        "return_on_deployed_capital_pct": float((profit / initial_capital) * 100.0),
        "return_on_total_capital_pct": float((profit / total_capital_usd) * 100.0) if total_capital_usd else np.nan,
    })
    return row


def finish_summary_many(strategy: Strategy, capitals: np.ndarray, equity: np.ndarray, pos_value: np.ndarray, fee_total: np.ndarray, fees_reinvested: np.ndarray, fees_uncollected_end: np.ndarray, rebalance_costs: np.ndarray, in_range: np.ndarray, share: np.ndarray, hodl50: np.ndarray, price: np.ndarray, rebalances: int, total_capital_usd: float) -> List[Dict[str, Any]]:
    """Vectorized finish_summary: build one summary row per capital level.

    2-D inputs are (capital_level, event); 1-D inputs are shared across levels.
    """
    n_caps = len(capitals)
    in_mask = in_range.astype(bool)
    if in_mask.any():
        # Share stats (percent) restricted to in-range events, per capital level.
        pct_share = share[:, in_mask] * 100.0
        avg_s = np.mean(pct_share, axis=1)
        p95_s = np.percentile(pct_share, 95, axis=1)
        p99_s = np.percentile(pct_share, 99, axis=1)
        max_s = np.max(pct_share, axis=1)
    else:
        avg_s = np.zeros(n_caps, dtype=np.float64)
        p95_s = np.zeros(n_caps, dtype=np.float64)
        p99_s = np.zeros(n_caps, dtype=np.float64)
        max_s = np.zeros(n_caps, dtype=np.float64)

    # Row-wise max drawdown; zero peaks are masked with NaN to avoid 0-division.
    running_peak = np.maximum.accumulate(equity, axis=1)
    drawdown = equity / np.where(running_peak == 0, np.nan, running_peak) - 1.0
    mdd = np.nanmin(drawdown, axis=1) * 100.0
    profit = equity[:, -1] - capitals
    time_in_range = float(in_mask.mean() * 100.0)
    price_return = float((price[-1] / price[0] - 1.0) * 100.0)

    rows: List[Dict[str, Any]] = []
    for i, cap in enumerate(capitals):
        row = base_summary(strategy, float(cap))
        row.update({
            "equity_end_usd": float(equity[i, -1]),
            "return_pct": float((equity[i, -1] / cap - 1.0) * 100.0),
            "mdd_pct": float(mdd[i]),
            "fees_earned_total": float(fee_total[i]),
            "fees_reinvested": float(fees_reinvested[i]),
            "fees_uncollected_end": float(fees_uncollected_end[i]),
            "rebalance_costs": float(rebalance_costs[i]),
            "position_value_end_usd": float(pos_value[i, -1]),
            "time_in_range_pct": time_in_range,
            "avg_liquidity_share_pct_when_in_range": float(avg_s[i]),
            "p95_liquidity_share_pct_when_in_range": float(p95_s[i]),
            "p99_liquidity_share_pct_when_in_range": float(p99_s[i]),
            "max_liquidity_share_pct_when_in_range": float(max_s[i]),
            "rebalances": int(rebalances),
            "hodl50_return_pct": float((hodl50[i, -1] / cap - 1.0) * 100.0),
            "vs_hodl50_usd": float(equity[i, -1] - hodl50[i, -1]),
            "price_start": float(price[0]),
            "price_end": float(price[-1]),
            "price_return_pct": price_return,
            "profit_usd": float(profit[i]),
            "deployable_capital_usd": float(cap),
            "unused_capital_usd": float(max(0.0, total_capital_usd - cap)),
            "return_on_deployed_capital_pct": float((profit[i] / cap) * 100.0),
            "return_on_total_capital_pct": float((profit[i] / total_capital_usd) * 100.0) if total_capital_usd else np.nan,
        })
        rows.append(row)
    return rows


def static_backtest(price: np.ndarray, input_usd: np.ndarray, active_liq: np.ndarray, ts: np.ndarray, dec0: int, dec1: int, initial_capital: float, strategy: Strategy, fee_rate: float, total_capital_usd: float, want_curve: bool = False) -> Tuple[Dict[str, Any], Optional[Any]]:
    """Replay a never-rebalanced range position; return (summary_row, optional curve df).

    With no collections, every earned fee stays uncollected until the end, so
    fees_earned_total == fees_uncollected_end.
    """
    open_price = float(price[0])
    lower = open_price * (1.0 - strategy.lower_pct / 100.0)
    upper = open_price * (1.0 + strategy.upper_pct / 100.0)
    our_liq = liquidity_for_capital(initial_capital, open_price, lower, upper, dec0, dec1)

    in_range = (price >= lower) & (price <= upper)
    share = liquidity_share(our_liq, active_liq)
    # Fees accrue only on events where the price sits inside our range.
    fee_events = np.zeros_like(price, dtype=np.float64)
    fee_events[in_range] = input_usd[in_range] * fee_rate * share[in_range]
    fees_cum = np.cumsum(fee_events)

    a0, a1 = amounts_raw_for_liquidity_vec(our_liq, price, lower, upper, dec0, dec1)
    pos_value = value_usd_from_raw(a0, a1, price, dec0, dec1)
    equity = pos_value + fees_cum
    hodl50 = initial_capital / 2.0 + (initial_capital / 2.0 / open_price) * price

    row = finish_summary(base_summary(strategy, initial_capital), equity, pos_value, float(fees_cum[-1]), 0.0, float(fees_cum[-1]), 0.0, in_range.astype(np.int8), share, hodl50, price, initial_capital, 0, total_capital_usd)
    if not want_curve:
        return row, None
    # pd is the module-level pandas import.
    curve = pd.DataFrame({
        "timestamp": ts,
        "datetime_utc": pd.to_datetime(ts, unit="s", utc=True),
        "price": price,
        "equity": equity,
        "position_value": pos_value,
        "fees_earned_total": fees_cum,
        "fees_uncollected": fees_cum,
        "in_range": in_range.astype(np.int8),
        "liquidity_share_pct": share * 100.0,
        "hodl50": hodl50,
        "lower_price": lower,
        "upper_price": upper,
    })
    return row, curve


def static_backtest_many(price: np.ndarray, sqrt_price: np.ndarray, input_usd: np.ndarray, active_liq: np.ndarray, ts: np.ndarray, dec0: int, dec1: int, capitals: np.ndarray, strategy: Strategy, fee_rate: float, total_capital_usd: float) -> List[Dict[str, Any]]:
    """Vectorized static replay across several capital levels at once."""
    open_price = float(price[0])
    lower = open_price * (1.0 - strategy.lower_pct / 100.0)
    upper = open_price * (1.0 + strategy.upper_pct / 100.0)
    unit = unit_value_curve_from_sqrt(sqrt_price, price, lower, upper, dec0, dec1)
    unit0 = max(float(unit[0]), 1e-300)
    liq = capitals / unit0

    in_range = (price >= lower) & (price <= upper)
    # Broadcast layout: rows are capital levels, columns are swap events.
    share = liq[:, None] / (active_liq[None, :] + liq[:, None])
    fee_events = np.where(in_range[None, :], input_usd[None, :] * fee_rate * share, 0.0)
    fees_cum = np.cumsum(fee_events, axis=1)
    pos_value = liq[:, None] * unit[None, :]
    equity = pos_value + fees_cum
    hodl50 = capitals[:, None] / 2.0 + (capitals[:, None] / 2.0 / open_price) * price[None, :]

    no_rebalance = np.zeros(len(capitals), dtype=np.float64)
    return finish_summary_many(
        strategy,
        capitals,
        equity,
        pos_value,
        fees_cum[:, -1],
        no_rebalance,          # fees_reinvested: nothing is ever collected
        fees_cum[:, -1],       # everything earned remains uncollected
        np.zeros(len(capitals), dtype=np.float64),  # rebalance costs
        in_range.astype(np.int8),
        share,
        hodl50,
        price,
        0,
        total_capital_usd,
    )


def periodic_backtest(price: np.ndarray, input_usd: np.ndarray, active_liq: np.ndarray, ts: np.ndarray, dec0: int, dec1: int, initial_capital: float, strategy: Strategy, fee_rate: float, total_capital_usd: float, want_curve: bool = False) -> Tuple[Dict[str, Any], Optional[Any]]:
    """Replay a rebalancing range position over the swap stream.

    Segments between rebalances are vectorized; at each rebalance the range is
    re-centered on the current price, uncollected fees are reinvested, and gas
    plus proportional swap costs are deducted from the redeployed value.
    Handles strategy.mode 'periodic' (pure time-based) and 'oor' (time-based,
    but only once the price has left the range). Returns (summary_row,
    optional per-event curve DataFrame).
    """
    n = len(price)
    p0 = float(price[0])
    lower = p0 * (1.0 - strategy.lower_pct / 100.0)
    upper = p0 * (1.0 + strategy.upper_pct / 100.0)
    capital = float(initial_capital)
    our_liq = liquidity_for_capital(capital, p0, lower, upper, dec0, dec1)

    last_reb_ts = int(ts[0])
    interval = int(strategy.rebalance_hours * 3600)
    # Separate fee buckets: uncollected (inside the position), lifetime earned,
    # and the amount rolled back in at rebalances.
    fees_uncollected = 0.0
    fees_earned_total = 0.0
    fees_reinvested = 0.0
    costs_cum = 0.0
    rebalances = 0

    equity_arr = np.empty(n, dtype=np.float64)
    pos_arr = np.empty(n, dtype=np.float64)
    fees_total_arr = np.empty(n, dtype=np.float64)
    fees_uncol_arr = np.empty(n, dtype=np.float64)
    in_arr = np.zeros(n, dtype=np.int8)
    share_arr = np.zeros(n, dtype=np.float64)
    hodl50 = initial_capital / 2.0 + (initial_capital / 2.0 / p0) * price

    idx = 0
    while idx < n:
        # Locate the next rebalance event index (n means "no more rebalances").
        if interval <= 0:
            next_idx = n
        else:
            start_reb = int(np.searchsorted(ts, last_reb_ts + interval, side="left"))
            if start_reb >= n:
                next_idx = n
            elif strategy.mode == "oor":
                # oor: after the interval elapses, rebalance at the first event
                # whose price is outside the current range.
                out_mask = (price[start_reb:] < lower) | (price[start_reb:] > upper)
                next_idx = start_reb + int(np.argmax(out_mask)) if out_mask.any() else n
            else:
                next_idx = start_reb

        if next_idx > idx:
            # Vectorized replay of the fixed-range segment [idx, next_idx).
            sl = slice(idx, next_idx)
            pseg = price[sl]
            in_range = (pseg >= lower) & (pseg <= upper)
            share = liquidity_share(our_liq, active_liq[sl])
            fee_events = np.zeros_like(pseg, dtype=np.float64)
            fee_events[in_range] = input_usd[sl][in_range] * fee_rate * share[in_range]
            fees_cum = np.cumsum(fee_events)
            a0, a1 = amounts_raw_for_liquidity_vec(our_liq, pseg, lower, upper, dec0, dec1)
            pos_value = value_usd_from_raw(a0, a1, pseg, dec0, dec1)

            equity_arr[sl] = pos_value + fees_uncollected + fees_cum
            pos_arr[sl] = pos_value
            fees_total_arr[sl] = fees_earned_total + fees_cum
            fees_uncol_arr[sl] = fees_uncollected + fees_cum
            in_arr[sl] = in_range.astype(np.int8)
            share_arr[sl] = share

            segment_fees = float(fees_cum[-1]) if len(fees_cum) else 0.0
            fees_uncollected += segment_fees
            fees_earned_total += segment_fees
            idx = next_idx

        if idx >= n:
            break

        # Rebalance before processing current swap event, matching v2 behavior.
        p = float(price[idx])
        a0, a1 = amounts_raw_for_liquidity_vec(our_liq, np.array([p]), lower, upper, dec0, dec1)
        pos_val = float(value_usd_from_raw(a0, a1, np.array([p]), dec0, dec1)[0])
        # Everything (position + accrued fees) is pulled, charged costs, and redeployed.
        redeploy = pos_val + fees_uncollected
        cost = strategy.gas_usd + redeploy * (strategy.swap_cost_bps / 10000.0)
        fees_reinvested += fees_uncollected
        fees_uncollected = 0.0
        costs_cum += cost
        redeploy = max(0.0, redeploy - cost)

        # Re-center the range on the current price.
        capital = redeploy
        lower = p * (1.0 - strategy.lower_pct / 100.0)
        upper = p * (1.0 + strategy.upper_pct / 100.0)
        our_liq = liquidity_for_capital(capital, p, lower, upper, dec0, dec1)
        last_reb_ts = int(ts[idx])
        rebalances += 1

    row = finish_summary(base_summary(strategy, initial_capital), equity_arr, pos_arr, fees_earned_total, fees_reinvested, fees_uncollected, costs_cum, in_arr, share_arr, hodl50, price, initial_capital, rebalances, total_capital_usd)
    curve = None
    if want_curve:
        import pandas as pd
        curve = pd.DataFrame({
            "timestamp": ts,
            "datetime_utc": pd.to_datetime(ts, unit="s", utc=True),
            "price": price,
            "equity": equity_arr,
            "position_value": pos_arr,
            "fees_earned_total": fees_total_arr,
            "fees_uncollected": fees_uncol_arr,
            "in_range": in_arr,
            "liquidity_share_pct": share_arr * 100.0,
            "hodl50": hodl50,
        })
    return row, curve


def periodic_backtest_many(price: np.ndarray, sqrt_price: np.ndarray, input_usd: np.ndarray, active_liq: np.ndarray, ts: np.ndarray, dec0: int, dec1: int, capitals: np.ndarray, strategy: Strategy, fee_rate: float, total_capital_usd: float) -> List[Dict[str, Any]]:
    """Vectorized periodic_backtest across all capital levels simultaneously.

    Same segment/rebalance state machine as periodic_backtest, but every fee
    bucket and the liquidity are arrays of shape (len(capitals),), and
    per-event arrays are (capital_level, event). Rebalance timing depends only
    on price/time, so all capital levels share the same rebalance schedule.
    """
    n = len(price)
    c = len(capitals)
    p0 = float(price[0])
    lower = p0 * (1.0 - strategy.lower_pct / 100.0)
    upper = p0 * (1.0 + strategy.upper_pct / 100.0)
    unit0 = max(unit_value_one_from_sqrt(float(sqrt_price[0]), p0, lower, upper, dec0, dec1), 1e-300)
    capital = capitals.astype(np.float64).copy()
    our_liq = capital / unit0

    last_reb_ts = int(ts[0])
    interval = int(strategy.rebalance_hours * 3600)
    # Per-capital fee buckets, mirroring the scalar version.
    fees_uncollected = np.zeros(c, dtype=np.float64)
    fees_earned_total = np.zeros(c, dtype=np.float64)
    fees_reinvested = np.zeros(c, dtype=np.float64)
    costs_cum = np.zeros(c, dtype=np.float64)
    rebalances = 0

    equity_arr = np.empty((c, n), dtype=np.float64)
    pos_arr = np.empty((c, n), dtype=np.float64)
    in_arr = np.zeros(n, dtype=np.int8)
    share_arr = np.zeros((c, n), dtype=np.float64)
    hodl50 = capitals[:, None] / 2.0 + (capitals[:, None] / 2.0 / p0) * price[None, :]

    idx = 0
    while idx < n:
        # Locate the next rebalance event index (n means "no more rebalances").
        if interval <= 0:
            next_idx = n
        else:
            start_reb = int(np.searchsorted(ts, last_reb_ts + interval, side="left"))
            if start_reb >= n:
                next_idx = n
            elif strategy.mode == "oor":
                # oor: after the interval, rebalance at the first out-of-range event.
                out_mask = (price[start_reb:] < lower) | (price[start_reb:] > upper)
                next_idx = start_reb + int(np.argmax(out_mask)) if out_mask.any() else n
            else:
                next_idx = start_reb

        if next_idx > idx:
            # Vectorized replay of the fixed-range segment [idx, next_idx).
            sl = slice(idx, next_idx)
            pseg = price[sl]
            in_range = (pseg >= lower) & (pseg <= upper)
            share = our_liq[:, None] / (active_liq[None, sl] + our_liq[:, None])
            fee_events = np.where(in_range[None, :], input_usd[None, sl] * fee_rate * share, 0.0)
            fees_cum = np.cumsum(fee_events, axis=1)
            unit = unit_value_curve_from_sqrt(sqrt_price[sl], pseg, lower, upper, dec0, dec1)
            pos_value = our_liq[:, None] * unit[None, :]

            equity_arr[:, sl] = pos_value + fees_uncollected[:, None] + fees_cum
            pos_arr[:, sl] = pos_value
            in_arr[sl] = in_range.astype(np.int8)
            share_arr[:, sl] = share

            segment_fees = fees_cum[:, -1] if fees_cum.shape[1] else np.zeros(c, dtype=np.float64)
            fees_uncollected += segment_fees
            fees_earned_total += segment_fees
            idx = next_idx

        if idx >= n:
            break

        # Rebalance before processing the current swap event (v2 semantics):
        # pull position + fees, charge costs, re-center and redeploy.
        p = float(price[idx])
        unit_reb = max(unit_value_one_from_sqrt(float(sqrt_price[idx]), p, lower, upper, dec0, dec1), 0.0)
        pos_val = our_liq * unit_reb
        redeploy = pos_val + fees_uncollected
        cost = strategy.gas_usd + redeploy * (strategy.swap_cost_bps / 10000.0)
        fees_reinvested += fees_uncollected
        fees_uncollected[:] = 0.0
        costs_cum += cost
        capital = np.maximum(0.0, redeploy - cost)

        lower = p * (1.0 - strategy.lower_pct / 100.0)
        upper = p * (1.0 + strategy.upper_pct / 100.0)
        unit_new = max(unit_value_one_from_sqrt(float(sqrt_price[idx]), p, lower, upper, dec0, dec1), 1e-300)
        our_liq = capital / unit_new
        last_reb_ts = int(ts[idx])
        rebalances += 1

    return finish_summary_many(strategy, capitals, equity_arr, pos_arr, fees_earned_total, fees_reinvested, fees_uncollected, costs_cum, in_arr, share_arr, hodl50, price, rebalances, total_capital_usd)


def fast_core_backtest_many(price: np.ndarray, sqrt_price: np.ndarray, input_usd: np.ndarray, active_liq: np.ndarray, ts: np.ndarray, dec0: int, dec1: int, capitals: np.ndarray, strategy: Strategy, fee_rate: float, total_capital_usd: float) -> Optional[List[Dict[str, Any]]]:
    """Run the replay through the compiled C core; return None if the core is unavailable.

    The core writes 24 per-capital output arrays whose positional order is an
    ABI contract with the C function — `keys` below must stay in exactly that
    order. Callers fall back to the pure-NumPy paths when this returns None.
    """
    lib = load_fast_core()
    if lib is None:
        return None
    # The ctypes signature requires C-contiguous float64/int64 buffers.
    price = np.ascontiguousarray(price, dtype=np.float64)
    sqrt_price = np.ascontiguousarray(sqrt_price, dtype=np.float64)
    input_usd = np.ascontiguousarray(input_usd, dtype=np.float64)
    active_liq = np.ascontiguousarray(active_liq, dtype=np.float64)
    ts = np.ascontiguousarray(ts, dtype=np.int64)
    capitals = np.ascontiguousarray(capitals, dtype=np.float64)
    c = len(capitals)
    outs = [np.empty(c, dtype=np.float64) for _ in range(24)]
    # Mode encoding for the core: 0 = static, 1 = periodic, 2 = out-of-range.
    mode = 0 if strategy.mode == "none" or strategy.rebalance_hours <= 0 else (2 if strategy.mode == "oor" else 1)
    lib.replay_many(
        price, sqrt_price, input_usd, active_liq, ts,
        len(price), dec0, dec1,
        capitals, c,
        float(strategy.lower_pct), float(strategy.upper_pct), mode, float(strategy.rebalance_hours),
        float(strategy.gas_usd), float(strategy.swap_cost_bps), float(fee_rate), float(total_capital_usd),
        *outs,
    )

    # Order must match the C core's output-argument order exactly.
    keys = [
        "equity_end_usd", "return_pct", "mdd_pct", "fees_earned_total",
        "fees_reinvested", "fees_uncollected_end", "rebalance_costs",
        "position_value_end_usd", "time_in_range_pct",
        "avg_liquidity_share_pct_when_in_range",
        "p95_liquidity_share_pct_when_in_range",
        "p99_liquidity_share_pct_when_in_range",
        "max_liquidity_share_pct_when_in_range", "rebalances",
        "hodl50_return_pct", "vs_hodl50_usd", "price_start", "price_end",
        "price_return_pct", "profit_usd", "deployable_capital_usd",
        "unused_capital_usd", "return_on_deployed_capital_pct",
        "return_on_total_capital_pct",
    ]
    rows: List[Dict[str, Any]] = []
    for i, cap in enumerate(capitals):
        row = base_summary(strategy, float(cap))
        for key, arr in zip(keys, outs):
            # The core returns everything as float64; rebalances is really a count.
            row[key] = int(arr[i]) if key == "rebalances" else float(arr[i])
        rows.append(row)
    return rows


def _init_worker(price: np.ndarray, sqrt_price: np.ndarray, input_usd: np.ndarray, active_liq: np.ndarray, ts: np.ndarray, dec0: int, dec1: int, capitals: np.ndarray, total_capital_usd: float, use_fast_core: bool) -> None:
    """Pool initializer: stash the shared replay inputs in this worker's globals."""
    shared = {
        "price": price,
        "sqrt_price": sqrt_price,
        "input_usd": input_usd,
        "active_liq": active_liq,
        "ts": ts,
        "dec0": dec0,
        "dec1": dec1,
        "capitals": capitals,
        "total_capital_usd": total_capital_usd,
        "use_fast_core": use_fast_core,
    }
    _WORKER_CTX.clear()
    _WORKER_CTX.update(shared)


def _run_strategy_worker(task: Tuple[str, float, Strategy]) -> Tuple[str, float, List[Dict[str, Any]]]:
    """Run one (fee, strategy) task against the worker context; returns (fee_name, fee_rate, rows)."""
    fee_name, fee_rate, st = task
    ctx = _WORKER_CTX
    shared = (ctx["price"], ctx["sqrt_price"], ctx["input_usd"], ctx["active_liq"], ctx["ts"], ctx["dec0"], ctx["dec1"], ctx["capitals"])
    if ctx.get("use_fast_core", False):
        # Prefer the C core; it returns None when unavailable.
        rows = fast_core_backtest_many(*shared, st, fee_rate, ctx["total_capital_usd"])
        if rows is not None:
            return fee_name, fee_rate, rows
    is_static = st.mode == "none" or st.rebalance_hours <= 0
    runner = static_backtest_many if is_static else periodic_backtest_many
    rows = runner(*shared, st, fee_rate, ctx["total_capital_usd"])
    return fee_name, fee_rate, rows


def score_row(row: Dict[str, Any], args: argparse.Namespace) -> float:
    """Capacity-aware score: return_pct minus weighted penalties for drawdown,
    liquidity-share footprint, and rebalance churn. Only the excess over each
    configured target/cap is penalized."""
    score = float(row["return_pct"])
    # (weight, excess-over-limit) pairs, applied in a fixed order.
    penalty_terms = (
        (args.w_mdd, abs(float(row["mdd_pct"])) - args.target_mdd_pct),
        (args.w_avg_share, float(row["avg_liquidity_share_pct_when_in_range"]) - args.max_avg_liquidity_share_pct),
        (args.w_p95_share, float(row["p95_liquidity_share_pct_when_in_range"]) - args.max_p95_liquidity_share_pct),
        (args.w_p99_share, float(row["p99_liquidity_share_pct_when_in_range"]) - args.max_p99_liquidity_share_pct),
        (args.w_max_share, float(row["max_liquidity_share_pct_when_in_range"]) - args.max_liquidity_share_pct),
    )
    for weight, excess in penalty_terms:
        score -= weight * max(0.0, excess)
    # Churn penalty is linear in the rebalance count, with no free allowance.
    score -= args.w_rebalance * float(row.get("rebalances", 0))
    return float(score)


def is_valid_capacity(row: Dict[str, Any], args: argparse.Namespace) -> bool:
    """Return True when the row's avg/p95/p99 in-range liquidity shares all sit within the CLI caps."""
    share_caps = (
        ("avg_liquidity_share_pct_when_in_range", args.max_avg_liquidity_share_pct),
        ("p95_liquidity_share_pct_when_in_range", args.max_p95_liquidity_share_pct),
        ("p99_liquidity_share_pct_when_in_range", args.max_p99_liquidity_share_pct),
    )
    return all(float(row[key]) <= cap for key, cap in share_caps)


def make_strategies(args: argparse.Namespace) -> List[Strategy]:
    """Assemble the strategy list from explicit --strategies specs and/or the
    lower/upper/rebalance grid, de-duplicated in first-seen order.

    Raises SystemExit when neither source yields any strategy.
    """
    candidates: List[Strategy] = []
    if args.strategies:
        for spec in args.strategies.split(","):
            if spec.strip():
                candidates.append(parse_strategy(spec))
    if args.grid_lower and args.grid_upper:
        lowers = parse_float_list(args.grid_lower)
        uppers = parse_float_list(args.grid_upper)
        reb_specs = parse_rebalance_grid(args.rebalance_grid)
        for lo in lowers:
            for up in uppers:
                for mode, reb_h in reb_specs:
                    if mode == "none":
                        name = f"wide_{lo:g}_{up:g}"
                    else:
                        name = f"{mode}_{lo:g}_{up:g}_{reb_h:g}h"
                    candidates.append(Strategy(
                        name=name,
                        lower_pct=lo,
                        upper_pct=up,
                        rebalance_hours=reb_h,
                        gas_usd=args.gas_usd,
                        swap_cost_bps=args.swap_cost_bps,
                        mode=mode,
                    ))
    # Keep the first occurrence of each unique configuration (dict preserves insertion order).
    unique: Dict[Tuple[Any, ...], Strategy] = {}
    for s in candidates:
        key = (s.name, s.lower_pct, s.upper_pct, s.rebalance_hours, s.mode, s.gas_usd, s.swap_cost_bps)
        unique.setdefault(key, s)
    if not unique:
        raise SystemExit("No strategies. Use --strategies or --grid-lower/--grid-upper.")
    return list(unique.values())


def make_capitals(args: argparse.Namespace) -> List[float]:
    """Return the capital amounts (USD) to simulate.

    "fixed" mode yields a single-element list from --initial-capital-usd;
    grid/auto_cap modes parse --capital-grid and reject an empty grid.
    """
    if args.capital_mode != "fixed":
        grid = parse_float_list(args.capital_grid)
        if not grid:
            raise SystemExit("Empty --capital-grid")
        return grid
    return [float(args.initial_capital_usd)]


def write_rows_csv(path: Path, rows: List[Dict[str, Any]]) -> None:
    """Write dict rows to ``path`` as CSV, using the first row's keys as the header.

    An empty row list produces an empty (header-less) file so downstream
    consumers still find the expected output path.
    """
    if not rows:
        path.write_text("", encoding="utf-8")
        return
    header = list(rows[0].keys())
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=header)
        writer.writeheader()
        for row in rows:
            writer.writerow(row)


def build_summary_rows(rows: List[Dict[str, Any]], args: argparse.Namespace) -> List[Dict[str, Any]]:
    """Rank result rows for the summary output.

    Outside auto_cap mode this is a plain score-descending sort. In auto_cap
    mode, for each (fee, strategy, range, rebalance) configuration the row with
    the largest deployable capital wins (score breaks ties), preferring rows
    that pass the avg/p95/p99 capacity checks; if none pass, all rows compete.
    """
    if args.capital_mode != "auto_cap":
        return sorted(rows, key=lambda r: float(r["score"]), reverse=True)

    capacity_ok = [r for r in rows if bool(r["valid_capacity_avg_p95_p99"])]
    pool = capacity_ok if capacity_ok else rows
    best: Dict[Tuple[Any, ...], Dict[str, Any]] = {}
    for row in pool:
        cfg = (row["fee_scenario"], row["strategy"], row["lower_pct"], row["upper_pct"], row["rebalance_mode"], row["rebalance_hours"])
        rank = (float(row["capital_usd"]), float(row["score"]))
        incumbent = best.get(cfg)
        if incumbent is None or rank > (float(incumbent["capital_usd"]), float(incumbent["score"])):
            best[cfg] = row
    return sorted(best.values(), key=lambda r: float(r["score"]), reverse=True)


def write_outputs_fast(rows: List[Dict[str, Any]], args: argparse.Namespace, out_dir: Path, suffix: str, price_len: int, time_from: str, time_to: str) -> List[Dict[str, Any]]:
    """Write summary/best CSVs and the JSON summary without pandas.

    Emits summary, best-by-score, best-by-return, the raw capital grid in
    auto_cap mode, and a per-strategy best-capacity file when any row passes
    the capacity checks. Returns the ranked summary rows.
    """
    summary = build_summary_rows(rows, args)
    by_score = sorted(summary, key=lambda r: float(r["score"]), reverse=True)
    by_return = sorted(summary, key=lambda r: float(r["return_pct"]), reverse=True)

    if args.capital_mode == "auto_cap":
        write_rows_csv(out_dir / f"capital_grid_all{suffix}.csv", rows)
    write_rows_csv(out_dir / f"summary{suffix}.csv", summary)
    write_rows_csv(out_dir / f"best_by_score{suffix}.csv", by_score[:args.top_n])
    write_rows_csv(out_dir / f"best_by_return{suffix}.csv", by_return[:args.top_n])

    capacity_rows = [r for r in summary if bool(r["valid_capacity_avg_p95_p99"])]
    if capacity_rows:
        # First row seen per configuration wins; iteration is score-descending.
        per_strategy: Dict[Tuple[Any, ...], Dict[str, Any]] = {}
        for r in sorted(capacity_rows, key=lambda x: float(x["score"]), reverse=True):
            cfg = (r["strategy"], r["lower_pct"], r["upper_pct"], r["rebalance_mode"], r["rebalance_hours"])
            per_strategy.setdefault(cfg, r)
        write_rows_csv(
            out_dir / f"best_capacity_by_strategy{suffix}.csv",
            sorted(per_strategy.values(), key=lambda r: float(r["score"]), reverse=True),
        )

    payload = {
        "script_version": SCRIPT_VERSION,
        "npz": str(args.npz),
        "time_from": time_from,
        "time_to": time_to,
        "rows": int(price_len),
        "capital_mode": args.capital_mode,
        "summary_csv": str(out_dir / f"summary{suffix}.csv"),
        "best_by_score": by_score[:20],
        "best_by_return": by_return[:20],
    }
    (out_dir / f"summary{suffix}.json").write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
    return summary


def run_once(data: Dict[str, Any], args: argparse.Namespace, out_dir: Path, time_from: str, time_to: str, suffix: str = "", return_frame: bool = True) -> Any:
    """Replay every (fee scenario, strategy, capital) combination over one time window.

    Each result row is tagged with run metadata, capacity validity, and a
    score, then written to summary/best CSV + JSON files in ``out_dir`` (file
    names carry ``suffix``). Returns a pandas summary DataFrame, or None when
    ``return_frame`` is False and no plot curves were collected (fast path).
    """
    # Slice the NPZ arrays down to the requested window.
    d = filter_time(data, time_from, time_to)
    meta = d.get("meta", {})
    # Non-zero CLI decimals override NPZ metadata; fall back to 6/18 when both absent.
    dec0 = args.dec0 or int(meta.get("dec0", 6))
    dec1 = args.dec1 or int(meta.get("dec1", 18))

    ts = d["ts"].astype(np.int64)
    price = d["price"].astype(np.float64)
    sqrt_price = sqrt_raw_token1_per_token0_from_price(price, dec0, dec1)
    input_usd = d["input_usd"].astype(np.float64)
    active_liq = d["active_liquidity"].astype(np.float64)

    strategies = make_strategies(args)
    capitals = make_capitals(args)
    fee_specs = parse_fee_specs(args.fee_rates)

    rows: List[Dict[str, Any]] = []
    curves: List[pd.DataFrame] = []
    # Equity curves are only collected when plotting was requested.
    want_any_curve = bool(args.plots)

    def add_batch_rows(fee_name: str, fee_rate: float, batch_rows: List[Dict[str, Any]]) -> None:
        # Annotate each backtest row with run metadata, capacity flags, and score.
        for row in batch_rows:
            cap = float(row["capital_usd"])
            row.update({
                "period_suffix": suffix,
                "time_from": time_from,
                "time_to": time_to,
                "fee_scenario": fee_name,
                "fee_rate": fee_rate,
                "run_name": f"{fee_name}__{row['strategy']}__cap_{cap:g}",
                "script_version": SCRIPT_VERSION,
            })
            row["valid_capacity_avg_p95_p99"] = is_valid_capacity(row, args)
            row["valid_capacity_plus_max"] = bool(row["valid_capacity_avg_p95_p99"] and row["max_liquidity_share_pct_when_in_range"] <= args.max_liquidity_share_pct)
            row["score"] = score_row(row, args)
            rows.append(row)

    cap_arr = np.asarray(capitals, dtype=np.float64)
    # One task per (fee scenario, strategy); capitals are batched inside each task.
    tasks = [(fee_name, fee_rate, st) for fee_name, fee_rate in fee_specs for st in strategies]
    jobs = int(args.jobs)
    if jobs <= 0:
        # 0 (or negative) means auto: one worker per task, capped at CPU count.
        jobs = min(len(tasks), max(1, os.cpu_count() or 1))
    # Parallel and fast-core paths only apply to multi-capital runs without plots.
    use_parallel = (not want_any_curve) and len(capitals) > 1 and jobs > 1 and len(tasks) > 1
    use_fast_core = bool(args.fast_core) and not want_any_curve and len(capitals) > 1

    if use_parallel:
        try:
            # Prefer fork so workers inherit the large arrays from the initializer cheaply.
            mp_ctx = mp.get_context("fork")
        except ValueError:
            # Fork unavailable on this platform; use the default start method.
            mp_ctx = mp.get_context()
        with mp_ctx.Pool(
            processes=jobs,
            initializer=_init_worker,
            initargs=(price, sqrt_price, input_usd, active_liq, ts, dec0, dec1, cap_arr, args.total_capital_usd, use_fast_core),
        ) as pool:
            chunksize = 1 if use_fast_core else max(1, len(tasks) // (jobs * 4))
            for fee_name, fee_rate, batch_rows in pool.imap_unordered(_run_strategy_worker, tasks, chunksize=chunksize):
                add_batch_rows(fee_name, fee_rate, batch_rows)
    else:
        # Serial path: populate the worker context in-process and reuse the same helpers.
        _init_worker(price, sqrt_price, input_usd, active_liq, ts, dec0, dec1, cap_arr, args.total_capital_usd, use_fast_core)
        for fee_name, fee_rate in fee_specs:
            for st in strategies:
                if not want_any_curve and len(capitals) > 1:
                    # Batched replay across all capitals at once; no curves needed.
                    _, _, batch_rows = _run_strategy_worker((fee_name, fee_rate, st))
                    add_batch_rows(fee_name, fee_rate, batch_rows)
                    continue

                for cap in capitals:
                    # Only keep curves when the total run count stays plottable.
                    want_curve = want_any_curve and len(strategies) * len(capitals) <= args.max_plot_runs
                    if st.mode == "none" or st.rebalance_hours <= 0:
                        row, curve = static_backtest(price, input_usd, active_liq, ts, dec0, dec1, cap, st, fee_rate, args.total_capital_usd, want_curve)
                    else:
                        row, curve = periodic_backtest(price, input_usd, active_liq, ts, dec0, dec1, cap, st, fee_rate, args.total_capital_usd, want_curve)
                    row.update({
                        "period_suffix": suffix,
                        "time_from": time_from,
                        "time_to": time_to,
                        "fee_scenario": fee_name,
                        "fee_rate": fee_rate,
                        "run_name": f"{fee_name}__{st.name}__cap_{cap:g}",
                        "script_version": SCRIPT_VERSION,
                    })
                    row["valid_capacity_avg_p95_p99"] = is_valid_capacity(row, args)
                    row["valid_capacity_plus_max"] = bool(row["valid_capacity_avg_p95_p99"] and row["max_liquidity_share_pct_when_in_range"] <= args.max_liquidity_share_pct)
                    row["score"] = score_row(row, args)
                    rows.append(row)
                    if curve is not None:
                        curve["run_name"] = row["run_name"]
                        curve["fee_scenario"] = fee_name
                        curves.append(curve)

    # Fast path: skip DataFrame construction entirely when the caller only wants files.
    if not return_frame and not curves:
        write_outputs_fast(rows, args, out_dir, suffix, len(price), time_from, time_to)
        return None

    # Local import mirrors the module-level `import pandas as pd`; harmless re-bind.
    import pandas as pd
    all_df = pd.DataFrame(rows)
    if args.capital_mode == "auto_cap":
        all_df.to_csv(out_dir / f"capital_grid_all{suffix}.csv", index=False)
        valid = all_df[all_df["valid_capacity_avg_p95_p99"]].copy()
        if len(valid):
            # Maximum deployable capital per strategy; tie-break by score.
            sort_cols = ["capital_usd", "score"]
            chosen = valid.sort_values(sort_cols, ascending=[False, False]).groupby(["fee_scenario", "strategy", "lower_pct", "upper_pct", "rebalance_mode", "rebalance_hours"], as_index=False).head(1)
            summary = chosen.sort_values("score", ascending=False).reset_index(drop=True)
        else:
            # No row passed the capacity checks; fall back to ranking everything.
            summary = all_df.sort_values("score", ascending=False).reset_index(drop=True)
    else:
        summary = all_df.sort_values("score", ascending=False).reset_index(drop=True)

    summary.to_csv(out_dir / f"summary{suffix}.csv", index=False)
    summary.sort_values("score", ascending=False).head(args.top_n).to_csv(out_dir / f"best_by_score{suffix}.csv", index=False)
    summary.sort_values("return_pct", ascending=False).head(args.top_n).to_csv(out_dir / f"best_by_return{suffix}.csv", index=False)

    # Best capacity-valid configuration per strategy (highest score wins via head(1)).
    valid_summary = summary[summary["valid_capacity_avg_p95_p99"]].copy()
    if len(valid_summary):
        best_capacity = valid_summary.sort_values("score", ascending=False).groupby(["strategy", "lower_pct", "upper_pct", "rebalance_mode", "rebalance_hours"], as_index=False).head(1)
        best_capacity.sort_values("score", ascending=False).to_csv(out_dir / f"best_capacity_by_strategy{suffix}.csv", index=False)

    if curves:
        curves_df = pd.concat(curves, ignore_index=True)
        curves_df.to_csv(out_dir / f"curves{suffix}.csv", index=False)
        make_plots(curves_df, summary, out_dir, suffix)

    # Machine-readable run summary alongside the CSVs.
    result = {
        "script_version": SCRIPT_VERSION,
        "npz": str(args.npz),
        "time_from": time_from,
        "time_to": time_to,
        "rows": int(len(price)),
        "capital_mode": args.capital_mode,
        "summary_csv": str(out_dir / f"summary{suffix}.csv"),
        "best_by_score": summary.sort_values("score", ascending=False).head(20).to_dict(orient="records"),
        "best_by_return": summary.sort_values("return_pct", ascending=False).head(20).to_dict(orient="records"),
    }
    (out_dir / f"summary{suffix}.json").write_text(json.dumps(result, indent=2, ensure_ascii=False), encoding="utf-8")
    return summary


def make_plots(curves: pd.DataFrame, summary: pd.DataFrame, out_dir: Path, suffix: str) -> None:
    """Render PNG diagnostics: equity curves per run, plus return/p95-share
    scatter plots versus capital when the summary has a capital_usd column."""
    import matplotlib
    matplotlib.use("Agg")  # headless backend: write files only
    import matplotlib.pyplot as plt

    plot_dir = out_dir / "plots"
    plot_dir.mkdir(parents=True, exist_ok=True)

    def _save(fig, stem):
        # Shared finalize/save/close sequence for every figure.
        fig.tight_layout()
        fig.savefig(plot_dir / f"{stem}{suffix}.png", dpi=160, bbox_inches="tight")
        plt.close(fig)

    fig, ax = plt.subplots(figsize=(12, 5))
    for run_name, run_curve in curves.groupby("run_name"):
        ax.plot(run_curve["datetime_utc"], run_curve["equity"], label=run_name)
    # Overlay the 50/50 HODL baseline taken from the first run's curve.
    baseline = curves[curves["run_name"] == curves["run_name"].iloc[0]]
    ax.plot(baseline["datetime_utc"], baseline["hodl50"], label="hodl50", linestyle="--")
    ax.set_title("Fast NPZ fee replay v3 equity")
    ax.set_ylabel("USD")
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=7)
    _save(fig, "equity_top_score")

    if "capital_usd" not in summary.columns:
        return

    def _scatter(column, title, y_label, stem):
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.scatter(summary["capital_usd"], summary[column])
        ax.set_title(title)
        ax.set_xlabel("capital_usd")
        ax.set_ylabel(y_label)
        ax.grid(True, alpha=0.3)
        _save(fig, stem)

    _scatter("return_pct", "Return vs capital", "return_pct", "return_vs_capital")
    _scatter("p95_liquidity_share_pct_when_in_range", "p95 share vs capital", "p95 liquidity share %", "p95_share_vs_capital")


def run_walkforward(data: Dict[str, Any], args: argparse.Namespace, out_dir: Path) -> None:
    """Tune on fixed periods and evaluate the champion out-of-sample.

    Runs ``run_once`` over hard-coded monthly/quarterly windows, then writes:
      - walkforward.csv: the top-score configuration of each train window,
        re-read from the matching test window's summary.
      - monthly_breakdown.csv: per-period rows for the top-20 quarter configs.
    """
    import pandas as pd
    # Hard-coded replay windows (UTC, [from, to) ).
    periods = [
        ("_feb", "2026-02-01T00:00:00Z", "2026-03-01T00:00:00Z"),
        ("_mar", "2026-03-01T00:00:00Z", "2026-04-01T00:00:00Z"),
        ("_apr", "2026-04-01T00:00:00Z", "2026-05-01T00:00:00Z"),
        ("_quarter", "2026-02-01T00:00:00Z", "2026-05-01T00:00:00Z"),
        ("_febmar", "2026-02-01T00:00:00Z", "2026-04-01T00:00:00Z"),
    ]
    outputs = {}
    for suffix, fr, to in periods:
        outputs[suffix] = run_once(data, args, out_dir, fr, to, suffix=suffix)

    # Basic walk-forward: take top score on train, evaluate same strategy/capital on test summaries.
    train_test = [("_feb", "_mar", "tune_feb_eval_mar"), ("_mar", "_apr", "tune_mar_eval_apr"), ("_febmar", "_apr", "tune_febmar_eval_apr")]
    wf_rows = []
    for train_key, test_key, label in train_test:
        train = outputs[train_key].sort_values("score", ascending=False)
        test = outputs[test_key]
        if train.empty:
            continue
        champ = train.iloc[0]
        mask = (
            (test["strategy"] == champ["strategy"])
            & (test["capital_usd"] == champ["capital_usd"])
            & (test["fee_scenario"] == champ["fee_scenario"])
        )
        if mask.any():
            ev = test[mask].iloc[0]
            wf_rows.append({
                "wf_case": label,
                "train_strategy": champ["strategy"],
                "capital_usd": champ["capital_usd"],
                "train_return_pct": champ["return_pct"],
                "train_mdd_pct": champ["mdd_pct"],
                "train_score": champ["score"],
                "test_return_pct": ev["return_pct"],
                "test_mdd_pct": ev["mdd_pct"],
                "test_score": ev["score"],
                "test_p95_share": ev["p95_liquidity_share_pct_when_in_range"],
                "test_p99_share": ev["p99_liquidity_share_pct_when_in_range"],
                "test_max_share": ev["max_liquidity_share_pct_when_in_range"],
            })
    pd.DataFrame(wf_rows).to_csv(out_dir / "walkforward.csv", index=False)

    # Monthly breakdown for top quarter configs.
    top_q = outputs["_quarter"].sort_values("score", ascending=False).head(20)
    monthly_rows = []
    for _, champ in top_q.iterrows():
        for suffix in ["_feb", "_mar", "_apr", "_quarter"]:
            df = outputs[suffix]
            # Bugfix: match fee_scenario too (as the walk-forward mask above does);
            # previously, multi-fee runs could report a row from the wrong fee scenario.
            mask = (
                (df["strategy"] == champ["strategy"])
                & (df["capital_usd"] == champ["capital_usd"])
                & (df["fee_scenario"] == champ["fee_scenario"])
            )
            if mask.any():
                r = df[mask].iloc[0].to_dict()
                r["source_period"] = suffix.strip("_")
                monthly_rows.append(r)
    pd.DataFrame(monthly_rows).to_csv(out_dir / "monthly_breakdown.csv", index=False)


def main() -> None:
    """CLI entry point: parse arguments, load the NPZ, and dispatch to either a
    single-window replay or the walk-forward suite."""
    parser = argparse.ArgumentParser()
    # Input data and output location.
    parser.add_argument("--npz", required=True)
    parser.add_argument("--out-dir", required=True)
    parser.add_argument("--fee-rates", default="metadata_0_2515:0.002515")
    parser.add_argument("--time-from", default="")
    parser.add_argument("--time-to", default="")
    parser.add_argument("--dec0", type=int, default=0)
    parser.add_argument("--dec1", type=int, default=0)
    # Strategy selection: explicit specs and/or a range/rebalance grid.
    parser.add_argument("--strategies", default="")
    parser.add_argument("--grid-lower", default="")
    parser.add_argument("--grid-upper", default="")
    parser.add_argument("--rebalance-grid", default="none:0,periodic:168,periodic:336,oor:24")
    parser.add_argument("--gas-usd", type=float, default=0.0)
    parser.add_argument("--swap-cost-bps", type=float, default=0.0)

    # Capital sizing and capacity caps.
    parser.add_argument("--initial-capital-usd", type=float, default=1000.0)
    parser.add_argument("--total-capital-usd", type=float, default=1000.0)
    parser.add_argument("--capital-grid", default="25,50,75,100,125,150,175,200,250,300,500,1000")
    parser.add_argument("--capital-mode", choices=["fixed", "grid", "auto_cap"], default="fixed")
    parser.add_argument("--target-share-pct", type=float, default=5.0)
    parser.add_argument("--share-cap-metric", choices=["avg", "p95", "p99", "max"], default="p95")
    parser.add_argument("--max-avg-liquidity-share-pct", type=float, default=3.0)
    parser.add_argument("--max-p95-liquidity-share-pct", type=float, default=5.0)
    parser.add_argument("--max-p99-liquidity-share-pct", type=float, default=10.0)
    parser.add_argument("--max-liquidity-share-pct", type=float, default=25.0)

    # Scoring weights (see score_row).
    parser.add_argument("--target-mdd-pct", type=float, default=25.0)
    parser.add_argument("--w-mdd", type=float, default=2.0)
    parser.add_argument("--w-avg-share", type=float, default=5.0)
    parser.add_argument("--w-p95-share", type=float, default=10.0)
    parser.add_argument("--w-p99-share", type=float, default=3.0)
    parser.add_argument("--w-max-share", type=float, default=0.5)
    parser.add_argument("--w-rebalance", type=float, default=0.02)

    # Execution mode and output knobs.
    parser.add_argument("--walkforward", action="store_true")
    parser.add_argument("--plots", action="store_true")
    parser.add_argument("--max-plot-runs", type=int, default=30)
    parser.add_argument("--top-n", type=int, default=50)
    parser.add_argument("--jobs", type=int, default=0, help="Parallel strategy workers. 0 = auto CPU count; 1 = disabled.")
    _add_bool_arg_compat(parser, "--fast-core", default=True, help="Use compiled C replay core when available.")
    opts = parser.parse_args()

    dest = Path(opts.out_dir)
    dest.mkdir(parents=True, exist_ok=True)
    dataset = load_npz(opts.npz)

    if opts.walkforward:
        run_walkforward(dataset, opts, dest)
    elif opts.time_from and opts.time_to:
        run_once(dataset, opts, dest, opts.time_from, opts.time_to, return_frame=False)
    else:
        raise SystemExit("Use --time-from and --time-to, or --walkforward")

    print(json.dumps({"script_version": SCRIPT_VERSION, "out_dir": str(dest)}, indent=2))


# Script entry point when executed directly (not on import).
if __name__ == "__main__":
    main()