#!/usr/bin/env python3
"""
trend_adaptive_lp_v1.py

Trend-Adaptive LP Strategy for Concentrated Liquidity Pools.

Idea (borrowed from DEMA trend-following on a CEX):
  - Compute an EMA / DEMA / SMA over the price series (one sample per swap event)
  - BULLISH (price > MA): deploy LP with a wide-upper range → capture upside + fees
  - BEARISH (price <= MA): deploy LP with a wide-lower range OR exit to USDC
  - On a trend reversal: rebalance the LP range (subject to a cooldown)

Bearish modes (--bearish-modes):
  exit:        on BEARISH, exit to USDC (neutral position); LP is deployed
               only while the trend is BULLISH
  wide_lower:  on BEARISH, re-deploy with the bear range (--bear-lowers /
               --bear-uppers), typically wide below and narrow above
  stay:        on BEARISH, re-deploy with the bull range re-centred on the
               current price
"""
from __future__ import annotations
import argparse, json, math
from pathlib import Path
import numpy as np
import pandas as pd

# Version tag printed by main() at start-up, identifying which revision of this
# script produced a given run's output.
SCRIPT_VERSION = "trend_adaptive_lp_v1_2026_05_04"


# ── Indicators ──────────────────────────────────────────────────────────────

def ema(prices: np.ndarray, period: int) -> np.ndarray:
    """Exponential Moving Average.

    Uses the recursive form ``ema[i] = alpha * x[i] + (1 - alpha) * ema[i - 1]``
    with ``alpha = 2 / (period + 1)``, seeded with ``ema[0] = x[0]``. Because of
    the seed, the output has no NaN warm-up (unlike ``sma``).

    Args:
        prices: input series; the recurrence runs along axis 0. Non-floating
            input (e.g. integers) is promoted to float64 so the average is not
            truncated at every step.
        period: smoothing period, must be >= 1.

    Returns:
        Array of the same shape as ``prices``. An empty input gives an empty
        array.

    Raises:
        ValueError: if ``period < 1`` (alpha would leave the (0, 1] range).
    """
    if period < 1:
        raise ValueError(f"period must be >= 1, got {period}")
    values = np.asarray(prices)
    if not np.issubdtype(values.dtype, np.inexact):
        # np.empty_like on an int array would store every step truncated.
        values = values.astype(np.float64)
    result = np.empty_like(values)
    if len(values) == 0:
        return result
    alpha = 2.0 / (period + 1)
    result[0] = values[0]
    for i in range(1, len(values)):
        result[i] = alpha * values[i] + (1 - alpha) * result[i - 1]
    return result

def dema(prices: np.ndarray, period: int) -> np.ndarray:
    """Double Exponential Moving Average.

    ``DEMA = 2 * EMA(prices) - EMA(EMA(prices))``. Subtracting the smoothed EMA
    removes part of the EMA's lag, so DEMA reacts faster to reversals.
    """
    first_pass = ema(prices, period)
    lag_estimate = ema(first_pass, period)
    return first_pass * 2 - lag_estimate

def sma(prices: np.ndarray, period: int) -> np.ndarray:
    """Simple Moving Average over a trailing window of ``period`` samples.

    The first ``period - 1`` entries are NaN (warm-up). If the series is
    shorter than ``period``, every entry is NaN.

    Args:
        prices: input series; windows slide along axis 0. Non-floating input is
            promoted to float64 (a NaN cannot be stored in an int array).
        period: window length, must be >= 1.

    Returns:
        Array of the same shape as ``prices``.

    Raises:
        ValueError: if ``period < 1``.
    """
    if period < 1:
        raise ValueError(f"period must be >= 1, got {period}")
    values = np.asarray(prices)
    if not np.issubdtype(values.dtype, np.inexact):
        values = values.astype(np.float64)
    result = np.full_like(values, np.nan)
    if len(values) >= period:
        # Zero-copy view of every trailing window. Each window is averaged in
        # C instead of via a Python-level slice loop. A NaN only affects the
        # windows that contain it, unlike a cumulative-sum approach.
        windows = np.lib.stride_tricks.sliding_window_view(values, period, axis=0)
        result[period - 1:] = windows.mean(axis=-1)
    return result


# ── Uniswap V3 position math ────────────────────────────────────────────────

def sqrt_raw(price, dec0=6, dec1=18):
    """Return ``sqrt(10**(dec1 - dec0) / price)``, the square root of the raw price.

    ``price`` is floored at 1e-300 so zero or negative prices do not divide by
    zero. A higher human price maps to a smaller raw square-root price.
    """
    floored = max(price, 1e-300)
    scale = 10 ** (dec1 - dec0)
    return math.sqrt(scale / floored)

def liquidity_for_capital(capital, p0, lower, upper, dec0=6, dec1=18):
    """Return the liquidity L whose position is worth ``capital`` at price ``p0``.

    Computes the value of a unit-liquidity position over ``[lower, upper]`` at
    ``p0``, expressed in token0 units, then scales it to ``capital``. The
    amounts follow the same three cases as ``position_value``:

    * below the range: all token1;
    * above the range: all token0;
    * in range: a mix of both.

    Returns 0 when the unit value is not positive (e.g. ``lower >= upper``).
    """
    sqrt_cur = sqrt_raw(p0, dec0, dec1)
    # Higher human price -> lower raw sqrt price, so `upper` gives the low bound.
    sqrt_lo = sqrt_raw(upper, dec0, dec1)
    sqrt_hi = sqrt_raw(lower, dec0, dec1)
    span = sqrt_hi - sqrt_lo

    if p0 <= lower:
        unit_value = span / 10**dec1 * p0
    elif p0 >= upper:
        unit_value = span / (sqrt_lo * sqrt_hi) / 10**dec0
    else:
        token0_part = (sqrt_hi - sqrt_cur) / (sqrt_cur * sqrt_hi) / 10**dec0
        token1_part = (sqrt_cur - sqrt_lo) / 10**dec1 * p0
        unit_value = token0_part + token1_part

    if unit_value > 1e-300:
        return capital / unit_value
    return 0

def position_value(L, price, lower, upper, dec0=6, dec1=18):
    """Return the value, in token0 units, of liquidity ``L`` over ``[lower, upper]``.

    Token amounts are computed in raw units from the square-root prices. Each
    amount is then divided by ``10**decimals``, and token1 is converted to
    token0 at ``price``. Fees are not included.
    """
    sqrt_cur = sqrt_raw(price, dec0, dec1)
    # Higher human price -> lower raw sqrt price, so `upper` gives the low bound.
    sqrt_lo = sqrt_raw(upper, dec0, dec1)
    sqrt_hi = sqrt_raw(lower, dec0, dec1)

    if price <= lower:
        # Below the range: the position is entirely token1.
        amount0 = 0.0
        amount1 = L * (sqrt_hi - sqrt_lo)
    elif price >= upper:
        # Above the range: the position is entirely token0.
        amount0 = L * (sqrt_hi - sqrt_lo) / (sqrt_lo * sqrt_hi)
        amount1 = 0.0
    else:
        amount0 = L * (sqrt_hi - sqrt_cur) / (sqrt_cur * sqrt_hi)
        amount1 = L * (sqrt_cur - sqrt_lo)

    human0 = amount0 / 10**dec0
    human1 = amount1 / 10**dec1
    return human0 + human1 * price

def max_drawdown(equity: np.ndarray) -> float:
    """Return the maximum drawdown of an equity curve, in percent (<= 0).

    The drawdown at each point is ``equity / running_peak - 1``. Points whose
    peak is 0 are ignored. NaN samples are skipped when tracking the running
    peak, so a single gap does not hide every later drawdown.

    Returns:
        * 0.0 for an empty curve;
        * NaN if no drawdown can be computed (e.g. all-NaN input);
        * otherwise the most negative drawdown times 100.
    """
    values = np.asarray(equity, dtype=np.float64)
    if values.size == 0:
        return 0.0
    # fmax ignores NaN. np.maximum.accumulate would make every later peak NaN.
    peak = np.fmax.accumulate(values)
    dd = values / np.where(peak == 0, np.nan, peak) - 1.0
    if np.isnan(dd).all():
        return float('nan')
    return float(np.nanmin(dd) * 100.0)


# ── Main simulation ─────────────────────────────────────────────────────────

def run_trend_adaptive(
    prices: np.ndarray,
    volumes: np.ndarray,
    active_liq: np.ndarray,
    ts: np.ndarray,
    capital: float,
    fee_rate: float,
    dec0: int,
    dec1: int,
    # Trend params
    ma_type: str,        # 'ema' | 'dema' | 'sma'
    ma_period: int,      # periods (in swap-events)
    # Bearish mode
    bearish_mode: str,   # 'exit' | 'wide_lower' | 'stay'
    # Range params
    bull_lower_pct: float,   # % below entry in bull trend
    bull_upper_pct: float,   # % above entry in bull trend
    bear_lower_pct: float,   # % below entry in bear trend (if wide_lower)
    bear_upper_pct: float,   # % above entry in bear trend
    # Rebalance cooldown
    rebalance_cooldown_h: float = 168,  # min hours between rebalances
) -> dict:
    """Run a trend-adaptive LP simulation over a series of swap events.

    A moving average of ``prices`` acts as the trend filter. Price strictly
    above the MA is bullish; otherwise it is bearish. While the MA is still
    undefined (SMA warm-up) there is no signal.

    On a trend change that passes the cooldown, the current position is
    closed. A new range is then opened around the current price:

    * bullish: bull range ``[p*(1-bull_lower_pct/100), p*(1+bull_upper_pct/100)]``;
    * bearish and ``'wide_lower'``: the same construction with ``bear_*_pct``;
    * bearish and ``'stay'``: the bull range, re-centred on the current price;
    * bearish and ``'exit'`` (or any unrecognised mode): stay in cash.

    While the price is inside the range, each event earns
    ``pos_L / (active_liq + pos_L) * fee_rate * volume``.
    NOTE(review): this assumes ``active_liq`` excludes our own position --
    confirm against the data source. Fees are held separately and paid out to
    cash when the position is closed.

    Args:
        prices, volumes, active_liq, ts: per-swap-event arrays of equal length.
            ``ts`` is compared against ``rebalance_cooldown_h * 3600``, so it is
            presumably in seconds.
        capital: starting cash, in token0 units.
        fee_rate: pool fee fraction applied to each event's volume.
        dec0, dec1: token decimals passed to the position math.
        ma_type: ``'dema'``, ``'sma'``; anything else uses EMA.
        ma_period: MA period, counted in swap events.
        bearish_mode: behaviour on a bearish trend (see above).
        bull_lower_pct, bull_upper_pct: bull range widths, in % of entry price.
        bear_lower_pct, bear_upper_pct: bear range widths (``'wide_lower'`` only).
        rebalance_cooldown_h: minimum hours between trend-driven rebalances.

    Returns:
        A dict of summary metrics:

        * return/drawdown: ``return_pct``, ``mdd_pct``, ``equity_end``;
        * activity: ``fees_earned`` (all fees over the run), ``rebalances``,
          ``trend_changes``;
        * time shares: ``time_in_lp_pct``, ``time_in_range_pct``;
        * pool-share stats: ``avg_share_pct``, ``p99_share_pct``;
        * echo of inputs: ``price_start``, ``price_end``, ``ma_type``,
          ``ma_period``, ``bearish_mode``.

    Raises:
        ValueError: if ``prices`` is empty.
    """
    n = len(prices)
    if n == 0:
        raise ValueError("run_trend_adaptive needs at least one swap event")

    # Compute MA
    if ma_type == 'dema':
        ma_vals = dema(prices, ma_period)
    elif ma_type == 'sma':
        ma_vals = sma(prices, ma_period)
    else:
        ma_vals = ema(prices, ma_period)

    # Per-event diagnostics.
    equity_arr = np.empty(n, dtype=np.float64)
    in_range_arr = np.zeros(n, dtype=np.int8)
    in_lp_arr = np.zeros(n, dtype=np.int8)
    share_arr = np.zeros(n, dtype=np.float64)

    cash = capital      # USDC held while not in the LP
    fees_cum = 0.0      # fees accrued by the *current* position (paid out on exit)
    fees_total = 0.0    # every fee accrued over the run (reporting only)
    pos_L = 0.0
    pos_lower = 0.0
    pos_upper = 0.0
    pos_entry = 0.0
    in_lp = False
    last_rebal_ts = ts[0]
    rebalances = 0
    trend_changes = 0
    prev_trend = 0      # last valid trend; 0 until the MA first becomes defined
    cooldown_span = rebalance_cooldown_h * 3600

    def enter_lp(p0, lower_pct, upper_pct, cap):
        # Open a range [p0*(1-lower%), p0*(1+upper%)] using all of `cap`.
        # Does nothing if the computed liquidity is not positive.
        nonlocal pos_L, pos_lower, pos_upper, pos_entry, in_lp, cash
        lo = p0 * (1 - lower_pct / 100)
        up = p0 * (1 + upper_pct / 100)
        L = liquidity_for_capital(cap, p0, lo, up, dec0, dec1)
        if L <= 0:
            return
        pos_lower, pos_upper, pos_L, pos_entry = lo, up, L, p0
        in_lp = True
        cash = 0.0

    def exit_lp(p_now):
        # Close the position at p_now; its value plus accrued fees become cash.
        nonlocal pos_L, pos_lower, pos_upper, pos_entry, in_lp, cash, fees_cum
        if not in_lp:
            return
        val = position_value(pos_L, p_now, pos_lower, pos_upper, dec0, dec1)
        cash = val + fees_cum
        fees_cum = 0.0
        pos_L = pos_lower = pos_upper = pos_entry = 0.0
        in_lp = False

    for i in range(n):
        p = prices[i]
        al = active_liq[i]
        v = volumes[i]
        t = ts[i]

        # Trend signal: 1 = bull, -1 = bear, 0 = MA undefined (no signal).
        # Treating warm-up as "no signal" rather than bearish avoids counting a
        # spurious bear->bull trend change when the SMA first becomes valid.
        ma = ma_vals[i]
        if np.isnan(ma):
            trend = 0
        else:
            trend = 1 if p > ma else -1

        cooldown_ok = (t - last_rebal_ts) >= cooldown_span

        # Trend change → rebalance?
        # NOTE(review): a flip that happens during the cooldown is not deferred.
        # prev_trend is still updated below, so that flip is dropped. Only the
        # initial-entry branch further down can open a position then, and it
        # bypasses the cooldown.
        if (trend != 0 and prev_trend != 0 and trend != prev_trend
                and cooldown_ok):
            trend_changes += 1
            if in_lp:
                exit_lp(p)
            if trend == 1:
                # Switched to BULLISH → enter bull LP
                enter_lp(p, bull_lower_pct, bull_upper_pct, cash)
                rebalances += 1
            else:
                # Switched to BEARISH
                if bearish_mode == 'wide_lower':
                    enter_lp(p, bear_lower_pct, bear_upper_pct, cash)
                    rebalances += 1
                elif bearish_mode == 'stay':
                    # Re-enter with the bull widths, centred on the new price
                    enter_lp(p, bull_lower_pct, bull_upper_pct, cash)
                    rebalances += 1
                # 'exit' → remain in cash/USDC
            last_rebal_ts = t

        # Entry whenever we are in cash and the signal is valid.
        # NOTE(review): 'stay' mode does not enter here on a bearish signal.
        if not in_lp and cash > 0 and trend != 0:
            if trend == 1:
                enter_lp(p, bull_lower_pct, bull_upper_pct, cash)
            elif bearish_mode == 'wide_lower':
                enter_lp(p, bear_lower_pct, bear_upper_pct, cash)

        if trend != 0:
            prev_trend = trend

        # Compute equity and fees
        if in_lp:
            in_range = pos_lower <= p <= pos_upper
            in_range_arr[i] = int(in_range)
            in_lp_arr[i] = 1
            if in_range and al > 0:
                share = pos_L / (al + pos_L)
                share_arr[i] = share
                fee = share * fee_rate * v
                fees_cum += fee
                fees_total += fee
            pos_val = position_value(pos_L, p, pos_lower, pos_upper, dec0, dec1)
            equity_arr[i] = pos_val + fees_cum
        else:
            equity_arr[i] = cash
            in_range_arr[i] = 0

    # Final exit: realise the position (and its fees) at the last price.
    exit_lp(prices[-1])
    equity_arr[-1] = cash

    in_range_in_lp = share_arr[in_lp_arr.astype(bool) & in_range_arr.astype(bool)]

    return {
        'return_pct': (equity_arr[-1] / capital - 1.0) * 100,
        'mdd_pct': max_drawdown(equity_arr),
        'equity_end': float(equity_arr[-1]),
        'fees_earned': float(fees_total),
        'rebalances': rebalances,
        'trend_changes': trend_changes,
        'time_in_lp_pct': float(in_lp_arr.mean() * 100),
        'time_in_range_pct': float(in_range_arr.mean() * 100),
        'avg_share_pct': float(in_range_in_lp.mean() * 100) if len(in_range_in_lp) else 0,
        'p99_share_pct': float(np.percentile(in_range_in_lp, 99) * 100) if len(in_range_in_lp) else 0,
        'price_start': float(prices[0]),
        'price_end': float(prices[-1]),
        'ma_type': ma_type,
        'ma_period': ma_period,
        'bearish_mode': bearish_mode,
    }


# ── Grid tuning ──────────────────────────────────────────────────────────────

def grid_tune(npz_path: str, out_dir: str, capital: float, fee_rate: float,
              dec0: int, dec1: int, days: float,
              ma_types, ma_periods,
              bearish_modes, cooldowns,
              bull_lowers, bull_uppers,
              bear_lowers, bear_uppers):
    """Grid-search ``run_trend_adaptive`` over parameter lists and write CSV reports.

    Swap events are loaded from ``npz_path``, which must contain the arrays
    ``price``, ``input_usd``, ``active_liquidity`` and ``ts``. One simulation
    is run per parameter combination. The following files are written to
    ``out_dir``:

    * ``summary.csv``: every configuration;
    * ``best_by_return.csv``: top 30 by ``annual_pct``;
    * ``best_valid.csv``: top 30 among configs with ``mdd_pct >= -20``,
      ``pnl_mdd >= 2`` and ``return_pct > 0``.

    The bear range (every ``bear_lowers`` x ``bear_uppers`` pair) is swept only
    for ``'wide_lower'``, the only mode that uses it. Other modes run once per
    bull range and record the bull range in the ``bear_*`` columns.

    Args:
        days: length of the data window in days. Used to annualise returns as
            ``annual_pct = return_pct / days * 365``. Must be > 0.
        The remaining arguments are forwarded to ``run_trend_adaptive``.

    Returns:
        ``(df, valid)``: all results and the filtered subset.

    Raises:
        ValueError: if ``days <= 0``.
    """
    import itertools

    if days <= 0:
        raise ValueError(f"days must be positive, got {days}")

    # NpzFile keeps the archive open; the context manager closes it.
    with np.load(npz_path, allow_pickle=False) as z:
        prices = z['price'].astype(np.float64)
        volumes = z['input_usd'].astype(np.float64)
        active_liq = z['active_liquidity'].astype(np.float64)
        ts = z['ts'].astype(np.int64)

    def bear_ranges(bmode, bl, bu):
        # Bear widths only matter for 'wide_lower'. Other modes get one
        # placeholder pair so their configurations are not duplicated.
        if bmode == 'wide_lower':
            return list(itertools.product(bear_lowers, bear_uppers))
        return [(bl, bu)]

    per_mode = (len(ma_types) * len(ma_periods) * len(cooldowns)
                * len(bull_lowers) * len(bull_uppers))
    n_bear = len(bear_lowers) * len(bear_uppers)
    total = per_mode * sum(n_bear if m == 'wide_lower' else 1 for m in bearish_modes)
    print(f"Grid: {total} configs on {len(prices)} swaps ({days:.1f}d)")

    rows = []
    grid = itertools.product(ma_types, ma_periods, bearish_modes,
                             cooldowns, bull_lowers, bull_uppers)
    for ma_t, ma_p, bmode, cd, bl, bu in grid:
        for bear_l, bear_u in bear_ranges(bmode, bl, bu):
            r = run_trend_adaptive(
                prices, volumes, active_liq, ts,
                capital, fee_rate, dec0, dec1,
                ma_t, ma_p, bmode, bl, bu, bear_l, bear_u, cd)
            mdd = r['mdd_pct']
            # Signed ratio, so losing configurations get a negative PnL/MDD.
            pnl_mdd = r['return_pct'] / abs(mdd) if mdd else 0
            rows.append({
                'ma_type': ma_t, 'ma_period': ma_p,
                'bearish_mode': bmode, 'cooldown_h': cd,
                'bull_lower_pct': bl, 'bull_upper_pct': bu,
                'bear_lower_pct': bear_l, 'bear_upper_pct': bear_u,
                **r,
                'annual_pct': r['return_pct'] / days * 365,
                'pnl_mdd': pnl_mdd,
            })

    df = pd.DataFrame(rows)
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    df.to_csv(out / 'summary.csv', index=False)

    if df.empty:
        # No configurations: there are no columns to filter on.
        print(f"Done: 0 configs, 0 valid (MDD<20%, PnL/MDD>2)")
        return df, df

    valid = df[(df['mdd_pct'] >= -20) & (df['pnl_mdd'] >= 2) & (df['return_pct'] > 0)]
    valid_sorted = valid.sort_values('annual_pct', ascending=False)
    df.sort_values('annual_pct', ascending=False).head(30).to_csv(out / 'best_by_return.csv', index=False)
    valid_sorted.head(30).to_csv(out / 'best_valid.csv', index=False)

    print(f"Done: {len(df)} configs, {len(valid)} valid (MDD<20%, PnL/MDD>2)")
    if len(valid):
        b = valid_sorted.iloc[0]
        print(f"Best: {b['ma_type']}({b['ma_period']}) {b['bearish_mode']} "
              f"bull={b['bull_lower_pct']:.0f}/{b['bull_upper_pct']:.0f} "
              f"→ {b['annual_pct']:.0f}% ann MDD={b['mdd_pct']:.1f}%")
    return df, valid


def main():
    """Command-line entry point: parse arguments and run ``grid_tune``.

    List-valued options are comma-separated. Whitespace around items and empty
    items (e.g. a trailing comma) are ignored. Unknown MA types or bearish
    modes are rejected instead of silently falling back to EMA / 'exit'.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--npz', required=True)
    ap.add_argument('--out-dir', required=True)
    ap.add_argument('--capital', type=float, default=600)
    ap.add_argument('--fee-rate', type=float, default=0.003)
    ap.add_argument('--dec0', type=int, default=6)
    ap.add_argument('--dec1', type=int, default=18)
    ap.add_argument('--days', type=float, default=30)
    ap.add_argument('--ma-types', default='dema,ema,sma')
    ap.add_argument('--ma-periods', default='50,100,200,500,1000,2000')
    ap.add_argument('--bearish-modes', default='exit,wide_lower,stay')
    ap.add_argument('--cooldowns', default='24,72,168')
    ap.add_argument('--bull-lowers', default='5,10,15,20,30,40')
    ap.add_argument('--bull-uppers', default='10,20,30,40,50,70,90')
    ap.add_argument('--bear-lowers', default='40,60,80')
    ap.add_argument('--bear-uppers', default='5,10,20')
    args = ap.parse_args()

    def parse_list(raw, cast, flag):
        # Split "a, b,,c" into [cast(a), cast(b), cast(c)]; exit via ap.error on bad input.
        items = [part.strip() for part in raw.split(',') if part.strip()]
        if not items:
            ap.error(f"{flag}: expected at least one value")
        try:
            return [cast(item) for item in items]
        except ValueError:
            ap.error(f"{flag}: cannot parse {raw!r}")

    def parse_choices(raw, allowed, flag):
        # Like parse_list, but every item must be one of `allowed`.
        items = parse_list(raw, str, flag)
        unknown = sorted(set(items) - allowed)
        if unknown:
            ap.error(f"{flag}: unknown value(s) {', '.join(unknown)}; "
                     f"expected any of {', '.join(sorted(allowed))}")
        return items

    ma_types = parse_choices(args.ma_types, {'ema', 'dema', 'sma'}, '--ma-types')
    ma_periods = parse_list(args.ma_periods, int, '--ma-periods')
    bearish_modes = parse_choices(args.bearish_modes,
                                  {'exit', 'wide_lower', 'stay'}, '--bearish-modes')
    cooldowns = parse_list(args.cooldowns, float, '--cooldowns')
    bull_lowers = parse_list(args.bull_lowers, float, '--bull-lowers')
    bull_uppers = parse_list(args.bull_uppers, float, '--bull-uppers')
    bear_lowers = parse_list(args.bear_lowers, float, '--bear-lowers')
    bear_uppers = parse_list(args.bear_uppers, float, '--bear-uppers')

    print(f"[{SCRIPT_VERSION}]")
    grid_tune(
        args.npz, args.out_dir, args.capital, args.fee_rate,
        args.dec0, args.dec1, args.days,
        ma_types,
        ma_periods,
        bearish_modes,
        cooldowns,
        bull_lowers,
        bull_uppers,
        bear_lowers,
        bear_uppers,
    )

if __name__ == '__main__':
    main()
