#!/usr/bin/env python3
from __future__ import annotations

import argparse
import asyncio
import datetime as dt
import json
import math
import os
import re
import sqlite3
import time
import urllib.parse
import urllib.request
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

from telethon import TelegramClient, events

try:
    import ccxt  # type: ignore
except Exception:
    ccxt = None


# Telegram channel whose messages are parsed as trade signals.
CHANNEL_ID = 3543793657
CHANNEL_TITLE = "Трейдинг золотом і валютами|Віталій DED USD"

# ccxt market symbols on BingX used as the mark-price source for metals.
# The ticker-style name and the plain word map to the same market.
_BINGX_GOLD = "NCCOGOLD2USD/USDT:USDT"
_BINGX_SILVER = "NCCOXAG2USD/USDT:USDT"
METAL_TO_BINGX = {
    "GOLD": _BINGX_GOLD,
    "XAUUSD": _BINGX_GOLD,
    "SILVER": _BINGX_SILVER,
    "XAGUSD": _BINGX_SILVER,
}

# Yahoo Finance chart tickers for FX pairs; USD-base pairs use Yahoo's short
# "<QUOTE>=X" form.
FX_TO_YAHOO = dict(
    EURUSD="EURUSD=X",
    GBPUSD="GBPUSD=X",
    USDCAD="CAD=X",
    USDCHF="CHF=X",
)


def utc_now() -> str:
    """Return the current time in UTC as an ISO-8601 string (``+00:00`` offset)."""
    stamp = dt.datetime.now(tz=dt.timezone.utc)
    return stamp.isoformat()


def load_env(path: str) -> Path:
    """Load ``KEY=VALUE`` pairs from a dotenv-style file into ``os.environ``.

    Variables that are already set are never overwritten (``setdefault``), so
    the real process environment takes precedence over the file. Blank lines,
    ``#`` comments and lines without ``=`` are skipped. Other rules:

    * an optional ``export`` prefix is accepted;
    * a UTF-8 BOM is ignored;
    * one pair of *matching* surrounding quotes is removed from the value;
    * lines with an empty key are skipped, because an empty name is not a
      valid environment variable.

    Returns the directory containing the file when it is a regular file
    (used by the caller as the base for relative paths), otherwise the
    current working directory.
    """
    env_path = Path(path)
    if not env_path.is_file():
        return Path.cwd()
    # utf-8-sig transparently drops a leading BOM, which would otherwise be
    # glued onto the first key.
    for raw in env_path.read_text(encoding="utf-8-sig").splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        key = key.strip()
        if key.startswith("export "):
            key = key[len("export "):].strip()
        if not key:
            continue
        value = value.strip()
        # Only strip quotes that actually enclose the value; an unmatched
        # quote is part of the value.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        os.environ.setdefault(key, value)
    return env_path.parent


def resolve_path(raw: str, base: Path) -> Path:
    """Turn *raw* into a path, anchoring relative paths at *base*.

    A leading ``~`` is expanded to the user's home directory first. Values
    read from an env file or passed in quotes are not shell-expanded, and
    otherwise ``~/x`` would silently become ``base / "~" / "x"``.
    """
    candidate = Path(raw).expanduser()
    if candidate.is_absolute():
        return candidate
    return base / candidate


# First alternative: comma-grouped thousands with a "." fraction ("2,345.50"),
# which is unambiguous. Second: a plain number with "." or "," as decimal mark.
_NUM_RE = re.compile(r"(?<!\d)(\d{1,3}(?:,\d{3})+\.\d+|\d+(?:[.,]\d+)?)(?!\d)")


def nums(raw: str) -> List[float]:
    """Extract every unsigned decimal number in *raw*, in order, as floats.

    Both ``.`` and ``,`` are accepted as the decimal separator, so ``4015,5``
    gives 4015.5. A comma-grouped number that also has a ``.`` fraction, such
    as ``2,345.50``, is read as 2345.5 instead of being split into 2.345 and
    50. A bare ``2,345`` is ambiguous and is still read as 2.345.
    """
    values: List[float] = []
    for token in _NUM_RE.findall(raw):
        if "," in token and "." in token:
            token = token.replace(",", "")  # thousands grouping
        else:
            token = token.replace(",", ".")  # decimal comma
        values.append(float(token))
    return values


def normalize_symbol(text: str) -> Optional[str]:
    """Return the first supported instrument named in *text*, or ``None``.

    Tickers are looked for after upper-casing and removing ``/`` and spaces,
    so ``xau / usd`` becomes ``XAUUSD``. They are tried in a fixed priority
    order. The words ``GOLD`` and ``SILVER`` are only considered when no
    ticker is present.
    """
    compact = text.upper().replace("/", "").replace(" ", "")
    tickers = ("XAUUSD", "XAGUSD", "EURUSD", "GBPUSD", "USDCAD", "USDCHF")
    ticker = next((t for t in tickers if t in compact), None)
    if ticker is not None:
        return ticker
    upper = text.upper()
    for word in ("GOLD", "SILVER"):
        if word in upper:
            return word
    return None


def parse_side(text: str) -> Optional[str]:
    """Infer the trade direction from arrows or buy/sell word stems.

    Matching is case-insensitive substring matching against English and
    Cyrillic stems. Long markers are tried first, so text containing both a
    long and a short marker is classified as ``"LONG"``.
    """
    folded = text.casefold()
    markers = (
        ("LONG", ("⬆", "buy", "покуп", "купів")),
        ("SHORT", ("⬇", "sell", "продаж", "прода")),
    )
    for side, needles in markers:
        if any(needle in folded for needle in needles):
            return side
    return None


# "t" may be Latin t or Cyrillic т (\u0442); "p" may be Latin p, Cyrillic р
# (\u0440, the look-alike of P) or Cyrillic п (\u043f, as in "тп").
_TP_LABEL_RE = re.compile(r"\b[t\u0442][p\u0440\u043f]\s*\d*\b")


def is_tp_line(line: str) -> bool:
    """Return True when *line* carries a take-profit label (``TP``, ``TP1`` ...).

    Accepts Latin ``tp``, Cyrillic ``тп`` and mixed-script look-alikes such
    as ``ТР`` typed on a Cyrillic layout, which renders identically to "TP".
    The label must be a standalone word (optionally followed by an index), so
    words such as "трейд" or "http" do not match.
    """
    return _TP_LABEL_RE.search(line.casefold()) is not None


# Standalone "sl" label, optionally followed directly by digits ("SL", "SL:",
# "sl2330"), but not words that merely start with "sl" ("slow").
_SL_LABEL_RE = re.compile(r"\bsl(?:\d|\b)")


def is_stop_line(line: str) -> bool:
    """Return True when *line* looks like a stop-loss line.

    Matches ``stop`` / ``стоп`` anywhere in the line, which also covers
    "stop loss" and "стоп лос". Also matches the standalone abbreviation
    ``SL``.
    """
    folded = line.casefold()
    if "стоп" in folded or "stop" in folded:
        return True
    return _SL_LABEL_RE.search(folded) is not None


def is_entry_line(line: str) -> bool:
    """Return True when *line* mentions an entry ("entry", "вход" or "вхід")."""
    folded = line.casefold()
    for keyword in ("вход", "вхід", "entry"):
        if keyword in folded:
            return True
    return False


def parse_dedusd_signal(text: str) -> Optional[Dict[str, Any]]:
    """Parse one channel message into a signal dict, or return ``None``.

    * Symbol and side are searched in the first four non-empty lines, falling
      back to the first 180 characters of the raw text.
    * Each line with numbers is classified as TP, stop or entry, checked in
      that order. The first number on a TP or stop line is collected, and the
      first two on an entry line. A leading TP index 1-5 (``TP1 2345``) is
      dropped, and a bare label such as ``TP1`` with no price is ignored.
    * Without an entry line, the first numeric non-TP/non-stop line among
      lines 2-5 supplies the declared entry.
    * The signal is rejected unless it has at least one TP and a stop, and
      every TP lies on the profitable side of the stop.

    Returns a dict with ``symbol``, ``side``, ``declared_entry_low`` and
    ``declared_entry_high`` (``None`` when no entry was found), ``tp1`` to
    ``tp4`` (missing TPs are ``None``), ``sl`` and ``raw_text``.
    """
    lines = [ln.strip() for ln in text.replace("\r", "\n").split("\n") if ln.strip()]
    if not lines:
        return None
    head = "\n".join(lines[:4])
    symbol = normalize_symbol(head) or normalize_symbol(text[:180])
    side = parse_side(head) or parse_side(text[:180])
    if not symbol or not side:
        return None

    tps: List[float] = []
    stop_values: List[float] = []
    entry_values: List[float] = []
    for line in lines:
        values = nums(line)
        if not values:
            continue
        if is_tp_line(line):
            if len(values) >= 2 and values[0] in (1.0, 2.0, 3.0, 4.0, 5.0):
                # "TP1 2345": the first number is the TP index, not a price.
                values = values[1:]
            elif re.fullmatch(r"\W*[t\u0442][p\u0440\u043f]\s*[1-5]\W*", line.casefold()):
                # Bare label such as "TP1" or "ТП 2:" with no price on the
                # line. Its only number is the index; recording it would create
                # a bogus TP of 1.0-5.0 that a SHORT would accept.
                continue
            tps.extend(values[:1])
        elif is_stop_line(line):
            stop_values.extend(values[:1])
        elif is_entry_line(line):
            entry_values.extend(values[:2])

    if not entry_values:
        # No labelled entry: use the first numeric line among lines 2-5 that
        # is neither a TP nor a stop line (the header line is skipped).
        for line in lines[1:5]:
            values = nums(line)
            if values and not is_tp_line(line) and not is_stop_line(line):
                entry_values.extend(values[:2])
                break

    if not tps or not stop_values:
        return None
    sl = stop_values[0]
    # Reject signals whose targets sit on the losing side of the stop.
    if side == "LONG" and not all(tp > sl for tp in tps):
        return None
    if side == "SHORT" and not all(tp < sl for tp in tps):
        return None

    declared_low: Optional[float] = None
    declared_high: Optional[float] = None
    if entry_values:
        # One value gives a single-price entry; two values give a range.
        first_two = entry_values[:2]
        declared_low, declared_high = min(first_two), max(first_two)

    return {
        "symbol": symbol,
        "side": side,
        "declared_entry_low": declared_low,
        "declared_entry_high": declared_high,
        "tp1": tps[0],
        "tp2": tps[1] if len(tps) > 1 else None,
        "tp3": tps[2] if len(tps) > 2 else None,
        "tp4": tps[3] if len(tps) > 3 else None,
        "sl": sl,
        "raw_text": text,
    }


class PriceProvider:
    """Latest-price lookup for signal symbols from a single asset-class source.

    With ``asset_class == "metals"`` the close of the newest 1-minute
    mark-price candle is read from the mapped BingX market via ccxt. Any other
    value (``"fx"`` from the CLI) reads the newest non-empty 1-minute close
    from Yahoo Finance's chart endpoint. All lookups are blocking network
    calls.
    """

    def __init__(self, asset_class: str):
        self.asset_class = asset_class
        self.exchange = None
        if asset_class == "metals":
            if ccxt is None:
                raise RuntimeError("ccxt is required for BingX metals mark provider")
            self.exchange = ccxt.bingx({"enableRateLimit": True})
            # Market metadata is loaded once, up front (blocking call).
            self.exchange.load_markets()

    def supports(self, symbol: str) -> bool:
        """Return True if *symbol* is mapped for this provider's asset class."""
        table = METAL_TO_BINGX if self.asset_class == "metals" else FX_TO_YAHOO
        return symbol in table

    def price(self, symbol: str) -> Dict[str, Any]:
        """Return the latest price for *symbol*.

        The result is a dict with ``price``, ``provider_symbol``, ``source``
        and ``source_time``, where ``source_time`` is the ISO-8601 UTC time of
        the bar the price came from.

        Raises KeyError for an unmapped symbol. Raises RuntimeError when the
        source returns no bar, or a price that is not a positive finite
        number (a zero entry price would later break P&L division). Network
        errors propagate unchanged.
        """
        if self.asset_class == "metals":
            return self._bingx_mark_price(symbol)
        return self._yahoo_close_price(symbol)

    @staticmethod
    def _checked(value: Any, origin: str) -> float:
        """Convert *value* to a positive finite float or raise RuntimeError."""
        try:
            price = float(value)
        except (TypeError, ValueError):
            raise RuntimeError(f"non-numeric price {value!r} from {origin}") from None
        if not math.isfinite(price) or price <= 0:
            raise RuntimeError(f"unusable price {price!r} from {origin}")
        return price

    def _bingx_mark_price(self, symbol: str) -> Dict[str, Any]:
        """Close of the newest 1m BingX mark-price candle for a metal symbol."""
        market = METAL_TO_BINGX[symbol]
        rows = self.exchange.fetch_mark_ohlcv(market, timeframe="1m", limit=2)
        if not rows:
            raise RuntimeError(f"no BingX mark rows for {market}")
        row = rows[-1]
        return {
            "price": self._checked(row[4], market),
            "provider_symbol": market,
            "source": "bingx_fetch_mark_ohlcv_1m_close",
            "source_time": dt.datetime.fromtimestamp(row[0] / 1000, dt.timezone.utc).isoformat(),
        }

    def _yahoo_close_price(self, symbol: str) -> Dict[str, Any]:
        """Newest non-null 1m close from Yahoo's chart API for an FX symbol."""
        yahoo = FX_TO_YAHOO[symbol]
        url = (
            "https://query2.finance.yahoo.com/v8/finance/chart/"
            + urllib.parse.quote(yahoo)
            + "?range=1d&interval=1m&includePrePost=true"
        )
        req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0", "Accept": "application/json"})
        with urllib.request.urlopen(req, timeout=15) as resp:
            data = json.loads(resp.read().decode("utf-8"))
        chart = data.get("chart") or {}
        result = (chart.get("result") or [None])[0]
        if not result:
            # Include Yahoo's own error payload so failures are diagnosable.
            raise RuntimeError(f"no Yahoo chart result for {yahoo} (error: {chart.get('error')!r})")
        timestamps = result.get("timestamp") or []
        quote = ((result.get("indicators") or {}).get("quote") or [{}])[0] or {}
        closes = quote.get("close") or []
        # Bars without trades carry a null close; walk back to the last real one.
        for ts, close in reversed(list(zip(timestamps, closes))):
            if close is not None:
                return {
                    "price": self._checked(close, yahoo),
                    "provider_symbol": yahoo,
                    "source": "yahoo_finance_chart_1m_close_proxy",
                    "source_time": dt.datetime.fromtimestamp(int(ts), dt.timezone.utc).isoformat(),
                }
        raise RuntimeError(f"no Yahoo close for {yahoo}")


def ensure_db(path: Path) -> None:
    """Create the SQLite file (and its parent directories) and the tables.

    Creates ``signals``, ``positions`` and ``price_events``. Safe to call
    repeatedly: every table uses ``CREATE TABLE IF NOT EXISTS``, so existing
    tables and their data are left untouched, including any older column
    layout. The connection is closed even if a statement fails.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    con = sqlite3.connect(path)
    try:
        with con:  # commit on success, roll back on error
            con.execute("""CREATE TABLE IF NOT EXISTS signals(
        telegram_message_id INTEGER PRIMARY KEY,
        ts_utc TEXT,
        received_at TEXT,
        source_channel TEXT,
        symbol TEXT,
        side TEXT,
        declared_entry_low REAL,
        declared_entry_high REAL,
        tp1 REAL,
        tp2 REAL,
        tp3 REAL,
        tp4 REAL,
        sl REAL,
        raw_text TEXT
    )""")
            con.execute("""CREATE TABLE IF NOT EXISTS positions(
        signal_id INTEGER PRIMARY KEY,
        symbol TEXT,
        side TEXT,
        status TEXT,
        entry_price REAL,
        entry_source TEXT,
        entry_source_time TEXT,
        provider_symbol TEXT,
        notional REAL,
        qty REAL,
        tp1 REAL,
        sl REAL,
        opened_at TEXT,
        updated_at TEXT,
        closed_at TEXT,
        exit_price REAL,
        exit_source TEXT,
        exit_source_time TEXT,
        realized_pnl REAL,
        r_multiple REAL
    )""")
            con.execute("""CREATE TABLE IF NOT EXISTS price_events(
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        ts_utc TEXT,
        symbol TEXT,
        provider_symbol TEXT,
        price REAL,
        source TEXT,
        source_time TEXT,
        reason TEXT
    )""")
    finally:
        con.close()


def exists(db: Path, message_id: int) -> bool:
    """Return True if a signal for Telegram *message_id* is stored in *db*.

    The connection is closed even when the query fails (for example, when the
    ``signals`` table is missing).
    """
    con = sqlite3.connect(db)
    try:
        row = con.execute(
            "SELECT 1 FROM signals WHERE telegram_message_id=?", (message_id,)
        ).fetchone()
    finally:
        con.close()
    return row is not None


def pnl_for(side: str, entry: float, exit_price: float, notional: float) -> float:
    """Return the P&L of a *notional*-sized position moved from *entry* to *exit_price*.

    The relative price move is applied to *notional*. Any side other than
    ``"LONG"`` is treated as short. A zero *entry* raises ZeroDivisionError.
    """
    gain = exit_price - entry if side == "LONG" else entry - exit_price
    return gain / entry * notional


def r_for(side: str, entry: float, exit_price: float, sl: float) -> float:
    """Return the trade result in R multiples (profit divided by stop distance).

    NaN is returned when the stop equals the entry (zero risk). Any side other
    than ``"LONG"`` is treated as short.
    """
    risk = abs(entry - sl)
    if risk <= 0:
        return math.nan
    gained = exit_price - entry if side == "LONG" else entry - exit_price
    return gained / risk


def insert_signal_position(db: Path, msg_id: int, msg_dt: str, sig: Dict[str, Any], px: Dict[str, Any], notional: float) -> None:
    """Record a parsed signal, its opening paper position and the entry price.

    Writes one row to each table:

    * ``signals``;
    * ``positions``, with status ``"open"`` and ``qty = notional / entry``
      (0 when the entry price is not positive);
    * ``price_events``, with reason ``"entry"``.

    Nothing is written when *msg_id* is already present in ``signals``. The
    duplicate check (``INSERT OR IGNORE`` on the primary key) and the three
    inserts run on one connection in one transaction, so a failure part-way
    leaves no partial rows. The connection is always closed. Columns are
    named explicitly so the inserts do not depend on table column order.
    """
    entry = float(px["price"])
    qty = notional / entry if entry > 0 else 0.0
    now = utc_now()
    con = sqlite3.connect(db)
    try:
        with con:  # one transaction: commit on success, roll back on error
            cur = con.execute(
                """INSERT OR IGNORE INTO signals(
                       telegram_message_id, ts_utc, received_at, source_channel, symbol, side,
                       declared_entry_low, declared_entry_high, tp1, tp2, tp3, tp4, sl, raw_text)
                   VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
                (
                    msg_id, msg_dt, now, CHANNEL_TITLE, sig["symbol"], sig["side"],
                    sig.get("declared_entry_low"), sig.get("declared_entry_high"),
                    sig["tp1"], sig.get("tp2"), sig.get("tp3"), sig.get("tp4"), sig["sl"], sig["raw_text"],
                ),
            )
            if cur.rowcount == 0:
                return  # this Telegram message was already recorded
            con.execute(
                """INSERT INTO positions(
                       signal_id, symbol, side, status, entry_price, entry_source, entry_source_time,
                       provider_symbol, notional, qty, tp1, sl, opened_at, updated_at, closed_at,
                       exit_price, exit_source, exit_source_time, realized_pnl, r_multiple)
                   VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
                (
                    msg_id, sig["symbol"], sig["side"], "open", entry, px["source"], px["source_time"],
                    px["provider_symbol"], notional, qty, sig["tp1"], sig["sl"], now, now,
                    None, None, None, None, 0.0, 0.0,
                ),
            )
            con.execute(
                """INSERT INTO price_events(ts_utc,symbol,provider_symbol,price,source,source_time,reason)
                   VALUES (?,?,?,?,?,?,?)""",
                (now, sig["symbol"], px["provider_symbol"], entry, px["source"], px["source_time"], "entry"),
            )
    finally:
        con.close()


def close_position(db: Path, pos: sqlite3.Row, px: Dict[str, Any], reason: str) -> None:
    """Close the paper position *pos* at ``px["price"]``.

    Stores the exit price and source, realized P&L and R multiple. Sets
    ``status`` to *reason* (the monitor passes ``"tp1"`` or ``"sl"``) and
    appends a ``price_events`` row.

    The update only touches a row whose status is still ``'open'``. A repeated
    call for an already-closed position therefore writes nothing, and cannot
    overwrite the first exit or log a second exit event. The update and the
    event insert share one transaction, and the connection is always closed.
    """
    exit_price = float(px["price"])
    entry = float(pos["entry_price"])
    side = str(pos["side"])
    pnl = pnl_for(side, entry, exit_price, float(pos["notional"]))
    r = r_for(side, entry, exit_price, float(pos["sl"]))
    now = utc_now()
    con = sqlite3.connect(db)
    try:
        with con:  # commit on success, roll back on error
            cur = con.execute(
                """UPDATE positions SET status=?, updated_at=?, closed_at=?, exit_price=?,
                   exit_source=?, exit_source_time=?, realized_pnl=?, r_multiple=?
                   WHERE signal_id=? AND status='open'""",
                (reason, now, now, exit_price, px["source"], px["source_time"], pnl, r, int(pos["signal_id"])),
            )
            if cur.rowcount:
                con.execute(
                    """INSERT INTO price_events(ts_utc,symbol,provider_symbol,price,source,source_time,reason)
                       VALUES (?,?,?,?,?,?,?)""",
                    (now, pos["symbol"], px["provider_symbol"], exit_price, px["source"], px["source_time"], reason),
                )
    finally:
        con.close()


def _load_open_positions(db: Path) -> List[sqlite3.Row]:
    """Return every position row whose status is 'open' (connection always closed)."""
    con = sqlite3.connect(db)
    try:
        con.row_factory = sqlite3.Row
        return con.execute("SELECT rowid, * FROM positions WHERE status='open'").fetchall()
    finally:
        con.close()


def _exit_reason(side: str, price: float, tp1: float, sl: float) -> Optional[str]:
    """Return "sl" or "tp1" if *price* has reached that level for *side*, else None.

    The stop is checked first, so a price past both levels counts as a stop.
    Sides other than "LONG"/"SHORT" never exit.
    """
    if side == "LONG":
        if price <= sl:
            return "sl"
        if price >= tp1:
            return "tp1"
    elif side == "SHORT":
        if price >= sl:
            return "sl"
        if price <= tp1:
            return "tp1"
    return None


async def monitor_open_positions(db: Path, provider: PriceProvider, poll_sec: float) -> None:
    """Poll prices forever and close open positions that reach TP1 or the stop.

    Each pass, run every *poll_sec* seconds:

    * all ``'open'`` positions are loaded;
    * each distinct symbol is priced once per pass;
    * the blocking provider call runs in the loop's default executor, so the
      Telegram client sharing this event loop keeps processing updates.

    Errors are printed as JSON lines. A failure for one position (or one
    symbol) does not skip the remaining positions in the pass.
    """
    loop = asyncio.get_running_loop()
    while True:
        try:
            positions = _load_open_positions(db)
        except Exception as exc:
            print(json.dumps({"ok": False, "where": "monitor", "error": repr(exc)}, ensure_ascii=False), flush=True)
            positions = []
        quotes: Dict[str, Dict[str, Any]] = {}
        for pos in positions:
            symbol = str(pos["symbol"])
            try:
                px = quotes.get(symbol)
                if px is None:
                    # NOTE(review): the provider now runs in a worker thread and
                    # may overlap an entry lookup; confirm ccxt tolerates that.
                    px = await loop.run_in_executor(None, provider.price, symbol)
                    quotes[symbol] = px
                reason = _exit_reason(str(pos["side"]), float(px["price"]), float(pos["tp1"]), float(pos["sl"]))
                if reason is not None:
                    close_position(db, pos, px, reason)
            except Exception as exc:
                print(
                    json.dumps(
                        {"ok": False, "where": "monitor", "signal_id": pos["signal_id"], "error": repr(exc)},
                        ensure_ascii=False,
                    ),
                    flush=True,
                )
        await asyncio.sleep(poll_sec)


async def _find_channel(client: Any, channel_id: int) -> Any:
    """Return the entity of the account dialog whose id equals *channel_id*, or None."""
    async for dialog in client.iter_dialogs():
        if getattr(dialog.entity, "id", None) == channel_id:
            return dialog.entity
    return None


async def run(args: argparse.Namespace) -> None:
    """Listen to the signal channel and paper-trade its signals until disconnect.

    Startup:

    * loads credentials from ``args.env_file``;
    * prepares the SQLite database and the JSONL log;
    * connects the Telethon session and locates the channel among the
      account's dialogs.

    Each new channel message that parses as a supported signal is opened as a
    paper position at the provider's current price. A background task closes
    positions at TP1 or the stop.

    Raises SystemExit when ``TG_API_ID``/``TG_API_HASH`` are missing, the
    session is not authorized, or the channel is not in the dialogs.
    """
    env_dir = load_env(args.env_file)
    api_id = os.environ.get("TG_API_ID")
    api_hash = os.environ.get("TG_API_HASH")
    if not api_id or not api_hash:
        raise SystemExit("TG_API_ID and TG_API_HASH must be set (environment or --env-file)")
    session = resolve_path(args.session or os.environ.get("TG_SESSION", "runs/telegram_paper/darkknight_session"), env_dir)
    out_jsonl = resolve_path(args.out_jsonl, Path.cwd())
    db = resolve_path(args.db, Path.cwd())
    out_jsonl.parent.mkdir(parents=True, exist_ok=True)
    ensure_db(db)
    provider = PriceProvider(args.asset_class)
    client = TelegramClient(str(session), int(api_id), api_hash)
    await client.connect()
    if not await client.is_user_authorized():
        await client.disconnect()
        raise SystemExit("Telethon session is not authorized")

    target = await _find_channel(client, CHANNEL_ID)
    if target is None:
        await client.disconnect()
        raise SystemExit(f"channel id {CHANNEL_ID} not found in dialogs")

    @client.on(events.NewMessage(chats=target))
    async def on_message(event: Any) -> None:
        msg = event.message
        sig = parse_dedusd_signal(msg.raw_text or "")
        if not sig or not provider.supports(str(sig["symbol"])):
            return
        msg_id = int(msg.id)
        if exists(db, msg_id):
            return
        msg_date = msg.date.isoformat() if msg.date else None
        try:
            # The provider does blocking network I/O; run it off the event loop
            # so Telegram updates keep flowing while the price is fetched.
            px = await asyncio.get_running_loop().run_in_executor(None, provider.price, str(sig["symbol"]))
            insert_signal_position(db, msg_id, msg_date or utc_now(), sig, px, args.notional)
            with out_jsonl.open("a", encoding="utf-8") as fp:
                fp.write(json.dumps({"telegram_message_id": msg_id, "telegram_message_date": msg_date, **sig, "entry_price": px}, ensure_ascii=False) + "\n")
            print(json.dumps({"ok": True, "asset_class": args.asset_class, "message_id": msg_id, "symbol": sig["symbol"], "side": sig["side"], "entry": px}, ensure_ascii=False), flush=True)
        except Exception as exc:
            print(json.dumps({"ok": False, "where": "entry", "message_id": msg_id, "error": repr(exc)}, ensure_ascii=False), flush=True)

    print(json.dumps({"ok": True, "mode": "paper_live", "channel": CHANNEL_TITLE, "asset_class": args.asset_class, "db": str(db), "out_jsonl": str(out_jsonl)}, ensure_ascii=False), flush=True)
    # Keep a reference: the event loop holds only weak references to tasks,
    # so an unreferenced task can be garbage-collected mid-execution.
    monitor = asyncio.create_task(monitor_open_positions(db, provider, args.poll_sec))
    try:
        await client.run_until_disconnected()
    finally:
        monitor.cancel()


def main() -> None:
    """Parse command-line options and run the listener until it disconnects."""
    parser = argparse.ArgumentParser()
    cli_options = (
        ("--env-file", dict(default="/var/www/vps2.happyuser.info/top/top_1/.env")),
        ("--session", dict(default="")),
        ("--asset-class", dict(choices=["metals", "fx"], required=True)),
        ("--out-jsonl", dict(required=True)),
        ("--db", dict(required=True)),
        ("--notional", dict(type=float, default=100.0)),
        ("--poll-sec", dict(type=float, default=30.0)),
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    asyncio.run(run(parser.parse_args()))


if __name__ == "__main__":
    main()
