#!/usr/bin/env python3
from __future__ import annotations
import argparse, json, time, yaml
import numpy as np
import pandas as pd
from pathlib import Path

try:
    from backtester_dual_core_dynamic_v5 import pick_symbol_block, parse_iso_to_epoch_s, simulate
except ImportError:
    from backtester_dual_core_dynamic_v2 import pick_symbol_block, parse_iso_to_epoch_s, simulate


def save_plot_bundle(curves: pd.DataFrame, plots_dir: str, prefix: str = 'dual') -> dict:
    """Render the standard PnL / notional / margin-excess PNG charts for a curves table.

    Args:
        curves: Per-bar curve table. When non-empty it must contain a ``bar_ts``
            column (parsed as UTC; unparseable rows are dropped). The plotted
            series ``realized_pnl``, ``realized_pnl_long``, ``realized_pnl_short``,
            ``total_pnl``, ``unrealized_pnl``, ``long_notional``, ``short_notional``,
            ``effective_notional``, ``allowed_notional`` and ``margin_excess`` are
            optional: missing ones are drawn as zeros or simply skipped.
        plots_dir: Output directory; created (with parents) if it does not exist.
        prefix: Filename prefix for every PNG written.

    Returns:
        ``{'plots_dir': <dir as str>, 'plots': [<png path>, ...]}``. ``plots`` is
        empty when no row has a parseable ``bar_ts``.

    Raises:
        KeyError: if ``curves`` is non-empty but has no ``bar_ts`` column.
    """
    plots_path = Path(plots_dir)
    plots_path.mkdir(parents=True, exist_ok=True)
    generated = []
    result = {'plots_dir': str(plots_path), 'plots': generated}
    if curves.empty:
        return result

    df = curves.copy()
    df['bar_ts'] = pd.to_datetime(df['bar_ts'], utc=True, errors='coerce')
    df = df.dropna(subset=['bar_ts']).sort_values('bar_ts')
    # Coercion can leave nothing behind (all timestamps unparseable); there is
    # nothing meaningful to draw then, so don't emit blank charts.
    if df.empty:
        return result

    # Imported lazily so callers without anything to plot never need matplotlib;
    # 'Agg' renders to files without a display.
    from contextlib import contextmanager
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    x = df['bar_ts']
    zeros = np.zeros(len(df))

    def series(name):
        """Return column *name*, or a zero line when the column is absent."""
        return df[name] if name in df.columns else zeros

    @contextmanager
    def figure(filename, nrows=1, **subplot_kw):
        """Yield fresh axes; save to plots_dir/filename on success, always close the figure."""
        fig, axes = plt.subplots(nrows, 1, **subplot_kw)
        try:
            yield axes
            out = plots_path / filename
            fig.tight_layout()
            fig.savefig(out, dpi=160, bbox_inches='tight')
            generated.append(str(out))
        finally:
            # Close even if drawing failed, so pyplot does not accumulate open figures.
            plt.close(fig)

    def finish(ax, title, ylabel):
        """Apply the common title / axis labels / grid / legend styling."""
        ax.set_title(title)
        ax.set_xlabel('Time (UTC)')
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)
        ax.legend()

    with figure(f'{prefix}_realized_pnl.png', figsize=(12, 5)) as ax:
        ax.plot(x, series('realized_pnl'), label='Realized total')
        for name, label in (('realized_pnl_long', 'Realized long'),
                            ('realized_pnl_short', 'Realized short')):
            if name in df.columns:
                ax.plot(x, df[name], label=label, alpha=0.9)
        finish(ax, 'Dual realized PnL', 'PnL')

    with figure(f'{prefix}_mtm_pnl.png', figsize=(12, 5)) as ax:
        ax.plot(x, series('total_pnl'), label='Total PnL (MTM)')
        for name, label in (('unrealized_pnl', 'Unrealized'),
                            ('realized_pnl', 'Realized')):
            if name in df.columns:
                ax.plot(x, df[name], label=label, alpha=0.8)
        finish(ax, 'Dual MTM PnL', 'PnL')

    with figure(f'{prefix}_pnl_panels_all.png', nrows=4, figsize=(13, 12), sharex=True) as axes:
        panels = (('realized_pnl', 'Realized PnL'),
                  ('unrealized_pnl', 'Unrealized PnL'),
                  ('total_pnl', 'Total PnL (MTM)'))
        for ax, (name, title) in zip(axes[:3], panels):
            ax.plot(x, series(name))
            ax.set_title(title)
            ax.grid(True, alpha=0.3)
        ax = axes[3]
        ax.plot(x, series('long_notional'), label='Long notional')
        ax.plot(x, series('short_notional'), label='Short notional')
        ax.set_title('Long / Short notional')
        ax.grid(True, alpha=0.3)
        ax.legend()
        ax.set_xlabel('Time (UTC)')

    if 'effective_notional' in df.columns and 'allowed_notional' in df.columns:
        with figure(f'{prefix}_margin_call_excess.png', figsize=(12, 5)) as ax:
            ax.plot(x, df['effective_notional'], label='Effective notional')
            ax.plot(x, df['allowed_notional'], label='Allowed notional')
            if 'margin_excess' in df.columns:
                # Shade only the positive part (notional above the allowed limit).
                excess = np.maximum(df['margin_excess'].astype(float).to_numpy(), 0.0)
                ax.fill_between(x, 0, excess, alpha=0.25, label='Excess over limit')
            finish(ax, 'Margin call excess', 'Notional')

    return result



def _median_bar_seconds(ts_s, default: float = 60.0) -> float:
    """Return the median spacing between the first 10,000 timestamps in *ts_s*.

    Falls back to *default* when fewer than two timestamps are available or the
    spacing cannot be computed (e.g. a non-numeric array). The value is not
    clamped here; callers apply their own lower bound.
    """
    if len(ts_s) < 2:
        return default
    try:
        return float(np.median(np.diff(ts_s[:10000])))
    except (TypeError, ValueError):
        return default


def _select_bars(bars: tuple, sel) -> tuple:
    """Apply one indexer to every per-bar series in *bars*.

    *bars* is ``(ts_s, open_, high, low, close, volume, extras)``, i.e. the order
    ``pick_symbol_block`` returns them without the leading symbol. *sel* is any
    NumPy indexer; here a boolean mask or a slice. Optional series that are
    ``None`` stay ``None``; every value in the ``extras`` dict is indexed too.
    """
    ts_s, open_, high, low, close, volume, extras = bars

    def take(arr):
        return None if arr is None else arr[sel]

    return (ts_s[sel], take(open_), take(high), take(low), close[sel], take(volume),
            {k: v[sel] for k, v in extras.items()})


def main():
    """Run one dual backtest from the command line and print the result as JSON.

    Loads the YAML config and NPZ market data and selects the symbol block. It can
    trim the bars to ``--time-from``/``--time-to``, keeping a warm-up window before
    ``--time-from`` that is passed to ``simulate`` via ``trade_start_ts_s``. It can
    also cut to the last ``--limit-bars`` bars. It then runs ``simulate`` and can
    export the per-bar curves as CSV and/or PNG plots.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--cfg', required=True)
    ap.add_argument('--npz', required=True)
    ap.add_argument('--symbol', default='')
    ap.add_argument('--limit-bars', type=int, default=0)
    ap.add_argument('--time-from', default='')
    ap.add_argument('--warmup-bars', type=int, default=0, help='Include N bars before --time-from to warm strategy indicators, but do not trade them')
    ap.add_argument('--warmup-hours', type=float, default=0.0, help='Include hours before --time-from as warmup; max with --warmup-bars')
    ap.add_argument('--time-to', default='')
    ap.add_argument('--export-curves', default='')
    ap.add_argument('--plots', default='')
    ap.add_argument('--dynamic-slippage-json', default='')
    args = ap.parse_args()
    t0 = time.time()

    # Close the config file as soon as it is parsed.
    with open(args.cfg, 'r', encoding='utf-8') as fh:
        cfg = yaml.safe_load(fh)
    model_override = json.loads(args.dynamic_slippage_json) if args.dynamic_slippage_json else None
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays stored in
    # the file; only point --npz at trusted data.
    data = np.load(args.npz, allow_pickle=True)
    market_symbol, ts_s, open_, high, low, close, volume, extras = pick_symbol_block(data, args.symbol)
    bars = (ts_s, open_, high, low, close, volume, extras)

    trade_start_ts_s = None
    if args.time_from:
        tf = parse_iso_to_epoch_s(args.time_from)
        trade_start_ts_s = int(tf)
        # Estimate the bar spacing once, from the unfiltered data; it converts
        # --warmup-hours into bars and bars back into a start timestamp.
        bar_sec = max(1.0, _median_bar_seconds(ts_s))
        warmup_bars = max(0, int(args.warmup_bars or 0))
        if args.warmup_hours:
            warmup_bars = max(warmup_bars, int(np.ceil(float(args.warmup_hours) * 3600.0 / bar_sec)))
        start_ts = int(tf)
        if warmup_bars > 0 and len(ts_s) >= 2:
            start_ts = int(tf - warmup_bars * bar_sec)
        bars = _select_bars(bars, bars[0] >= start_ts)
    if args.time_to:
        tt = parse_iso_to_epoch_s(args.time_to)
        bars = _select_bars(bars, bars[0] <= tt)
    if args.limit_bars > 0:
        bars = _select_bars(bars, slice(-args.limit_bars, None))
    ts_s, open_, high, low, close, volume, extras = bars

    need_curves = bool(args.export_curves or args.plots)
    out = simulate(cfg, ts_s, close, open_=open_, high=high, low=low, volume=volume, extras=extras,
                   market_symbol=market_symbol, model_override=model_override,
                   export_curves=need_curves, trade_start_ts_s=trade_start_ts_s)
    out['elapsed_sec'] = time.time() - t0
    curves = out.pop('curves', None)
    if curves is not None:
        if args.export_curves:
            Path(args.export_curves).parent.mkdir(parents=True, exist_ok=True)
            curves.to_csv(args.export_curves, index=False)
            out['curves_csv'] = args.export_curves
        if args.plots:
            out.update(save_plot_bundle(curves, args.plots, prefix='dual'))
    print(json.dumps(out, ensure_ascii=False, indent=2))


if __name__ == '__main__':
    # Run the CLI only when this file is executed as a script, not when imported.
    main()
