#!/usr/bin/env python3
"""Example market-making bot for EventTrader Exchange.

A simple two-sided market maker that places bid and ask orders around the
mid price, tracks fills, and manages position/P&L.

Supports two modes:
  1. Standalone — place orders directly on the CLOB (paper or live).
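  2. Clone Mode — create or attach to a cloned bot, fund it from your
     account balance, and trade with risk controls (SL/TP). Clone Mode
     trades LIVE by default; it only falls back to paper when the clone
     itself is a paper clone and --live is not given.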

Setup:
    pip install event-trader

Usage:
    # Paper mode (default — no real money):
    python scripts/example_exchange_bot.py --pair vaix --api-key evt_...

    # Live mode:
    python scripts/example_exchange_bot.py --pair vaix --api-key evt_... --live

    # With email/password auth:
    python scripts/example_exchange_bot.py --pair vaix --email user@example.com --password secret

    # Custom spread and size:
    python scripts/example_exchange_bot.py --pair vaix --api-key evt_... --spread-bps 100 --size 500

    # Clone Mode — create a new clone from a species and fund it:
    python scripts/example_exchange_bot.py --pair vaix --api-key evt_... \\
        --clone-from macd --clone-name "My MACD MM" --fund-amount 100

    # Clone Mode — attach to an existing clone:
    python scripts/example_exchange_bot.py --pair vaix --api-key evt_... --clone-id <id>

    # Clone Mode — with risk controls and auto-withdraw on stop:
    python scripts/example_exchange_bot.py --pair vaix --api-key evt_... \\
        --clone-from momentum --fund-amount 200 --max-loss 10 --max-position 5000 --withdraw-on-stop

    # Clone Mode — equip the market_making skill for spread capture:
    python scripts/example_exchange_bot.py --pair vaix --api-key evt_... \\
        --clone-from macd --fund-amount 100 --equip-skill market_making

    # Clone Mode — stack multiple skills:
    python scripts/example_exchange_bot.py --pair vaix --api-key evt_... \\
        --clone-from momentum --fund-amount 200 \\
        --equip-skill market_making --equip-skill rsi_momentum --equip-skill volatility_filter

Environment variables (alternative to CLI args):
    ET_API_KEY        API key
    ET_EMAIL          Email for login
    ET_USERNAME       Username for login (alternative to email)
    ET_PASSWORD       Password for login
    ET_BASE_URL       API base URL (default: https://cymetica.com)
"""

from __future__ import annotations

import argparse
import asyncio
import logging
import os
import signal
import sys
from datetime import datetime, timezone
from decimal import Decimal, ROUND_DOWN
from typing import Any

from event_trader import EventTrader

# Root logging configuration: INFO level, time-of-day stamp, level tag, message.
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
# Logger shared by every component in this script.
log = logging.getLogger("mm-bot")


# ---------------------------------------------------------------------------
# Position tracker
# ---------------------------------------------------------------------------

class PositionTracker:
    """Tracks net inventory, average entry price, and realized/unrealized P&L.

    Fills are fed in through :meth:`process_fills`; each distinct fill is
    applied exactly once. Inventory is signed: positive means long, negative
    means short. A single fill may reduce a position, close it, or close it
    and open a position on the opposite side.
    """

    def __init__(self, symbol: str):
        self.symbol = symbol.upper()
        self.inventory = Decimal("0")       # net base token position (+long / -short)
        self.realized_pnl = Decimal("0")    # P&L locked in by fills that reduced a position
        self.avg_entry = Decimal("0")       # average entry price of the open position (0 when flat)
        self.trade_count = 0                # number of distinct fills processed
        self._seen_trade_ids: set[str] = set()

    @staticmethod
    def _trade_key(trade: dict) -> str:
        """Return a de-duplication key for ``trade``.

        Uses ``trade_id`` / ``id`` when present. Records without an ID are
        keyed by a fingerprint of the whole record, so distinct ID-less fills
        are not all collapsed onto one empty key.
        NOTE(review): the fingerprint assumes the API returns the same fields
        for the same trade on every poll — confirm against the API.
        """
        tid = trade.get("trade_id") or trade.get("id")
        if tid:
            return str(tid)
        return repr(sorted(trade.items(), key=lambda kv: str(kv[0])))

    def _apply_fill(self, signed_qty: Decimal, price: Decimal) -> None:
        """Apply one fill; ``signed_qty`` is positive for buys, negative for sells."""
        if signed_qty == 0:
            return
        before = self.inventory
        size = abs(signed_qty)

        if before == 0 or (before > 0) == (signed_qty > 0):
            # Opening or extending: blend the fill into the average entry price.
            total_cost = self.avg_entry * abs(before) + size * price
            self.inventory = before + signed_qty
            self.avg_entry = total_cost / abs(self.inventory)
            return

        # Reducing: only the part that offsets the existing position realizes P&L.
        closed = min(size, abs(before))
        if before > 0:
            self.realized_pnl += closed * (price - self.avg_entry)
        else:
            self.realized_pnl += closed * (self.avg_entry - price)
        self.inventory = before + signed_qty

        if self.inventory == 0:
            self.avg_entry = Decimal("0")
        elif (self.inventory > 0) != (before > 0):
            # The fill went through flat: the remainder is a fresh position
            # opened at this fill's price.
            self.avg_entry = price

    def process_fills(self, trades: list[dict]) -> int:
        """Process new fills and update position. Returns count of new fills.

        Trades already seen (by key) are skipped. A trade is marked as seen
        before it is parsed, so a malformed record raises once and is not
        retried on later calls. Non-positive quantities and sides other than
        ``buy``/``sell`` are counted but do not change the position.
        """
        new_fills = 0
        for t in trades:
            key = self._trade_key(t)
            if key in self._seen_trade_ids:
                continue
            self._seen_trade_ids.add(key)
            new_fills += 1
            self.trade_count += 1

            qty = Decimal(str(t.get("quantity") or "0"))
            price = Decimal(str(t.get("price") or "0"))
            side = str(t.get("side") or "").lower()
            if qty <= 0:
                continue

            if side == "buy":
                self._apply_fill(qty, price)
            elif side == "sell":
                self._apply_fill(-qty, price)

        return new_fills

    def unrealized_pnl(self, mid_price: Decimal) -> Decimal:
        """Mark the open position to ``mid_price``; zero when flat or mid is 0."""
        if self.inventory == 0 or mid_price == 0:
            return Decimal("0")
        if self.inventory > 0:
            return self.inventory * (mid_price - self.avg_entry)
        return abs(self.inventory) * (self.avg_entry - mid_price)

    def total_pnl(self, mid_price: Decimal) -> Decimal:
        """Realized plus unrealized P&L at ``mid_price``."""
        return self.realized_pnl + self.unrealized_pnl(mid_price)

    def summary(self, mid_price: Decimal) -> str:
        """One-line, human-readable position and P&L report."""
        upnl = self.unrealized_pnl(mid_price)
        total = self.realized_pnl + upnl
        return (
            f"Position: {self.inventory:+} {self.symbol} | "
            f"Avg entry: {self.avg_entry:.6f} | "
            f"Realized: {self.realized_pnl:+.4f} USDC | "
            f"Unrealized: {upnl:+.4f} USDC | "
            f"Total P&L: {total:+.4f} USDC | "
            f"Trades: {self.trade_count}"
        )


# ---------------------------------------------------------------------------
# Price helpers
# ---------------------------------------------------------------------------

def round_to_tick(price: Decimal, tick_size: Decimal) -> Decimal:
    """Truncate ``price`` toward zero onto a whole multiple of ``tick_size``.

    A zero or negative ``tick_size`` disables rounding: ``price`` comes back
    untouched.
    """
    if tick_size > 0:
        whole_ticks = (price / tick_size).to_integral_value(rounding=ROUND_DOWN)
        return whole_ticks * tick_size
    return price


def round_to_lot(qty: Decimal, lot_size: Decimal) -> Decimal:
    """Truncate ``qty`` toward zero onto a whole multiple of ``lot_size``.

    A zero or negative ``lot_size`` disables rounding: ``qty`` comes back
    untouched.
    """
    if lot_size > 0:
        whole_lots = (qty / lot_size).to_integral_value(rounding=ROUND_DOWN)
        return whole_lots * lot_size
    return qty


# ---------------------------------------------------------------------------
# Clone bot manager — wraps /api/v1/cloned-bots/ REST API
# ---------------------------------------------------------------------------

class CloneBotManager:
    """Async convenience layer over the ``/api/v1/cloned-bots`` endpoints.

    Every public method issues exactly one request through the injected HTTP
    client and hands back whatever that client returns.
    """

    BASE = "/api/v1/cloned-bots"

    def __init__(self, http: Any):
        # Client object whose get/post/patch/delete methods are awaited below.
        self._http = http

    def _url(self, clone_id: str, tail: str = "") -> str:
        """Build the path for one clone, optionally followed by a sub-resource."""
        root = f"{self.BASE}/{clone_id}"
        return f"{root}/{tail}" if tail else root

    # --- Lifecycle ---

    async def create_clone(
        self,
        source_type: str,
        source_id: str,
        name: str | None = None,
        is_paper: bool | None = None,
    ) -> dict:
        """POST a new clone; ``name`` is sent as ``custom_name``."""
        payload = {
            "source_type": source_type,
            "source_id": source_id,
            "custom_name": name,
            "is_paper": is_paper,
        }
        return await self._http.post(f"{self.BASE}/clone", json=payload)

    async def list_my_clones(self) -> dict:
        """GET the ``my-clones`` listing."""
        return await self._http.get(f"{self.BASE}/my-clones")

    async def delete_clone(self, clone_id: str) -> dict:
        """DELETE the clone."""
        return await self._http.delete(self._url(clone_id))

    # --- Read-only views ---

    async def get_profile(self, clone_id: str) -> dict:
        """GET the clone's profile."""
        return await self._http.get(self._url(clone_id, "profile"))

    async def get_balance(self, clone_id: str) -> dict:
        """GET the clone's balance."""
        return await self._http.get(self._url(clone_id, "balance"))

    async def get_trades(self, clone_id: str, page: int = 1, page_size: int = 20) -> dict:
        """GET one page of the clone's ``activity`` feed."""
        paging = {"page": page, "page_size": page_size}
        return await self._http.get(self._url(clone_id, "activity"), params=paging)

    # --- Funds and settings ---

    async def fund(self, clone_id: str, amount: float) -> dict:
        """POST ``amount`` to the clone's ``fund`` endpoint."""
        return await self._http.post(self._url(clone_id, "fund"), json={"amount": amount})

    async def withdraw(self, clone_id: str, amount: float) -> dict:
        """POST ``amount`` to the clone's ``withdraw`` endpoint."""
        return await self._http.post(self._url(clone_id, "withdraw"), json={"amount": amount})

    async def update_settings(self, clone_id: str, settings: dict) -> dict:
        """PATCH the clone's settings with the given mapping."""
        return await self._http.patch(self._url(clone_id, "settings"), json=settings)

    async def toggle_trading(self, clone_id: str, enabled: bool) -> dict:
        """POST to ``toggle-trading``; the flag travels as a lowercase query string."""
        flag = str(enabled).lower()
        return await self._http.post(
            self._url(clone_id, "toggle-trading"),
            params={"enabled": flag},
        )

    # --- Skills ---

    async def get_skills(self, clone_id: str) -> dict:
        """GET the clone's skills."""
        return await self._http.get(self._url(clone_id, "skills"))

    async def equip_skill(self, clone_id: str, skill_id: str, slot_position: int | None = None) -> dict:
        """POST a skill to equip; ``slot_position`` is sent only when given."""
        payload: dict = {"skill_id": skill_id}
        if slot_position is not None:
            payload["slot_position"] = slot_position
        return await self._http.post(self._url(clone_id, "skills/equip"), json=payload)

    async def unequip_skill(self, clone_id: str, skill_id: str) -> dict:
        """DELETE one skill from the clone."""
        return await self._http.delete(self._url(clone_id, f"skills/{skill_id}"))


# ---------------------------------------------------------------------------
# Main bot
# ---------------------------------------------------------------------------

class MarketMaker:
    """Two-sided quoting loop with fill tracking and risk limits.

    Each cycle fetches the mid price, ingests new fills, checks the risk
    limits, and (re)places one post-only bid and one post-only ask around the
    mid when there are no resting orders or the mid has drifted too far from
    the price the current quotes were placed at.
    """

    def __init__(
        self,
        client: EventTrader,
        symbol: str,
        spread_bps: int,
        order_size: Decimal,
        refresh_interval: float,
        mode: str,
        drift_threshold_bps: int,
        *,
        max_session_loss: Decimal | None = None,
        max_position_size: Decimal | None = None,
        clone_stop_loss_pct: float | None = None,
        clone_take_profit_pct: float | None = None,
        initial_balance: Decimal = Decimal("0"),
    ):
        self.client = client
        self.symbol = symbol
        self.spread_bps = spread_bps
        self.order_size = order_size
        self.refresh_interval = refresh_interval
        self.mode = mode
        self.drift_threshold_bps = drift_threshold_bps
        self.tracker = PositionTracker(symbol)
        self.running = True
        # Pair parameters; defaults are used until initialize() replaces them.
        self.tick_size = Decimal("0.000001")
        self.lot_size = Decimal("1")
        self.min_order_size = Decimal("1")
        self._open_order_ids: list[str] = []
        # Mid price the resting quotes were placed around; drift is measured
        # from this reference rather than from the previous cycle's mid.
        self._quote_mid: Decimal | None = None

        # Risk controls
        self.max_session_loss = max_session_loss
        self.max_position_size = max_position_size
        self.clone_stop_loss_pct = clone_stop_loss_pct
        self.clone_take_profit_pct = clone_take_profit_pct
        self.initial_balance = initial_balance

    def check_risk_limits(self, mid_price: Decimal) -> bool:
        """Check risk limits. Returns True if trading should continue.

        Limits are evaluated against the tracker's total (realized +
        unrealized) P&L at ``mid_price``. The percentage stop-loss and
        take-profit only apply when ``initial_balance`` is positive. A
        non-positive ``mid_price`` skips all checks.
        """
        if mid_price <= 0:
            return True
        total_pnl = self.tracker.total_pnl(mid_price)

        if self.max_session_loss is not None and total_pnl <= -self.max_session_loss:
            log.warning("MAX SESSION LOSS hit: P&L %.4f <= -%.4f — stopping", total_pnl, self.max_session_loss)
            return False

        if self.clone_stop_loss_pct is not None and self.initial_balance > 0:
            sl_amount = self.initial_balance * Decimal(str(self.clone_stop_loss_pct)) / Decimal("100")
            if total_pnl <= -sl_amount:
                log.warning("CLONE STOP-LOSS hit: P&L %.4f <= -%.4f (%.1f%%) — stopping", total_pnl, sl_amount, self.clone_stop_loss_pct)
                return False

        if self.clone_take_profit_pct is not None and self.initial_balance > 0:
            tp_amount = self.initial_balance * Decimal(str(self.clone_take_profit_pct)) / Decimal("100")
            if total_pnl >= tp_amount:
                log.info("CLONE TAKE-PROFIT hit: P&L %.4f >= +%.4f (%.1f%%) — stopping", total_pnl, tp_amount, self.clone_take_profit_pct)
                return False

        return True

    async def initialize(self):
        """Fetch pair info, check balance, and mark pre-session trades as seen.

        Every step is best-effort: a failure is logged and the bot continues
        with its defaults.
        """
        log.info("Fetching pair info for %s...", self.symbol)
        try:
            pairs_resp = await self.client.clob.list_pairs()
            pairs = pairs_resp if isinstance(pairs_resp, list) else pairs_resp.get("pairs", [])
            for p in pairs:
                sym = (p.get("symbol") or p.get("pair") or "").lower()
                base = (p.get("base") or p.get("base_symbol") or "").lower()
                if sym == self.symbol.lower() or base == self.symbol.lower():
                    self.tick_size = Decimal(str(p.get("tick_size", self.tick_size)))
                    self.lot_size = Decimal(str(p.get("lot_size", self.lot_size)))
                    self.min_order_size = Decimal(str(p.get("min_order_size", self.min_order_size)))
                    log.info(
                        "Pair: %s | tick=%.8f | lot=%.4f | min_order=%.4f",
                        sym, self.tick_size, self.lot_size, self.min_order_size,
                    )
                    break
            else:
                log.warning("Pair %s not found in list_pairs — using defaults", self.symbol)
        except Exception as e:
            log.warning("Could not fetch pair info: %s — using defaults", e)

        try:
            bal = await self.client.clob.balance(self.symbol)
            log.info("Balance: %s", bal)
        except Exception as e:
            log.warning("Could not fetch balance: %s", e)

        await self._prime_seen_fills()

    async def _prime_seen_fills(self):
        """Register trades that predate this session without applying them.

        check_fills() polls the most recent trades; without this step the
        first poll would treat the account's trade history as fresh session
        fills and corrupt inventory, P&L, and the risk checks built on them.
        The history is run through the tracker so it records the trade keys
        with its own logic, then the tracker's position state is restored.
        """
        try:
            result = await self.client.clob.my_trades(self.symbol, limit=20)
            history = result if isinstance(result, list) else result.get("trades", [])
        except Exception as e:
            log.warning("Could not fetch trade history: %s", e)
            return
        if not history:
            return

        tracker = self.tracker
        saved = (tracker.inventory, tracker.realized_pnl, tracker.avg_entry, tracker.trade_count)
        try:
            skipped = tracker.process_fills(history)
            log.info("Ignoring %d pre-existing trade(s) from before this session", skipped)
        except Exception as e:
            log.warning("Could not pre-register trade history: %s", e)
        finally:
            (tracker.inventory, tracker.realized_pnl,
             tracker.avg_entry, tracker.trade_count) = saved

    async def cancel_all(self):
        """Cancel all tracked open orders.

        Individual failures are logged and do not stop the remaining cancels.
        The tracked list is cleared afterwards regardless of failures; the
        next refresh_open_orders() call re-syncs it from the exchange.
        """
        if not self._open_order_ids:
            return
        log.info("Cancelling %d open orders...", len(self._open_order_ids))
        for oid in self._open_order_ids:
            try:
                await self.client.clob.cancel_order(self.symbol, oid, mode=self.mode)
                log.info("  Cancelled %s", oid)
            except Exception as e:
                log.warning("  Failed to cancel %s: %s", oid, e)
        self._open_order_ids.clear()

    async def get_mid_price(self) -> Decimal | None:
        """Get mid price from BBO, falling back to the stats price.

        Returns None when neither source yields a price or a request fails.
        """
        try:
            bbo = await self.client.clob.bbo(self.symbol)
            best_bid = bbo.get("best_bid") or bbo.get("bid")
            best_ask = bbo.get("best_ask") or bbo.get("ask")
            if best_bid and best_ask:
                bid = Decimal(str(best_bid))
                ask = Decimal(str(best_ask))
                if bid > 0 and ask > 0:
                    return (bid + ask) / 2
            # Fallback: use stats
            stats = await self.client.clob.stats(self.symbol)
            price = stats.get("price") or stats.get("last_price")
            if price:
                return Decimal(str(price))
        except Exception as e:
            log.warning("Could not get mid price: %s", e)
        return None

    async def refresh_open_orders(self):
        """Sync our tracked order IDs with what's actually open.

        Orders that carry no usable ID are ignored: an empty ID cannot be
        cancelled and would make the bot believe it still has quotes resting.
        """
        try:
            result = await self.client.clob.open_orders(self.symbol, mode=self.mode)
            orders = result if isinstance(result, list) else result.get("orders", [])
            ids = (o.get("order_id") or o.get("id") for o in orders)
            self._open_order_ids = [oid for oid in ids if oid]
        except Exception as e:
            log.warning("Could not refresh open orders: %s", e)

    async def check_fills(self):
        """Check for new fills and update position."""
        try:
            result = await self.client.clob.my_trades(self.symbol, limit=20)
            trades = result if isinstance(result, list) else result.get("trades", [])
            new = self.tracker.process_fills(trades)
            if new:
                log.info("*** %d new fill(s) ***", new)
        except Exception as e:
            log.warning("Could not check fills: %s", e)

    def _quote_sizes(self) -> tuple[Decimal, Decimal]:
        """Return ``(bid_qty, ask_qty)`` honouring ``max_position_size``.

        Without a position limit both sides quote the lot-rounded order size.
        With a limit, the side that grows the position is capped to the
        remaining room, and the reducing side is capped so that a fill going
        through flat cannot open an opposite position beyond the limit. At or
        over the limit only the reducing side is quoted, and never past flat.
        """
        qty = round_to_lot(self.order_size, self.lot_size)
        if self.max_position_size is None:
            return qty, qty

        zero = Decimal("0")
        inventory = self.tracker.inventory
        exposure = abs(inventory)
        limit = self.max_position_size
        room = limit - exposure

        if room <= 0:
            reduce_qty = round_to_lot(min(qty, exposure), self.lot_size)
            if inventory > 0:
                return zero, reduce_qty
            if inventory < 0:
                return reduce_qty, zero
            return zero, zero

        grow_qty = round_to_lot(min(qty, room), self.lot_size)
        shrink_qty = round_to_lot(min(qty, limit + exposure), self.lot_size)
        if inventory > 0:
            return grow_qty, shrink_qty     # buying adds to long
        if inventory < 0:
            return shrink_qty, grow_qty     # selling adds to short
        return grow_qty, grow_qty           # flat — both sides open a position

    async def _place_quote(self, side: str, label: str, qty: Decimal, price: Decimal) -> None:
        """Place one post-only order and track its ID; failures are logged only."""
        if qty < self.min_order_size:
            return
        try:
            result = await self.client.clob.place_order(
                self.symbol, side, str(qty),
                price=str(price), order_type="post_only", mode=self.mode,
            )
            oid = result.get("order_id", result.get("id", ""))
            if oid:
                self._open_order_ids.append(oid)
            log.info("  %s placed: %s", label, oid or result)
        except Exception as e:
            log.warning("  %s failed: %s", label, e)

    async def run_cycle(self, last_mid: Decimal | None) -> Decimal | None:
        """Run one quoting cycle. Returns the current mid price.

        Returns ``last_mid`` unchanged when no mid price is available. Sets
        ``self.running`` to False when a risk limit is hit.
        """
        mid = await self.get_mid_price()
        if mid is None or mid == 0:
            log.warning("No mid price available — skipping cycle")
            return last_mid

        # Ingest fills first so the risk check below sees the current
        # position and P&L instead of the previous cycle's.
        await self.check_fills()

        if not self.check_risk_limits(mid):
            self.running = False
            return mid

        bid_qty, ask_qty = self._quote_sizes()

        # Decide whether to re-quote
        should_requote = not self._open_order_ids  # no orders out
        # Drift is measured from the mid the resting quotes were placed at, so
        # a slow move that stays under the threshold each cycle still triggers
        # a re-quote once it adds up.
        reference_mid = self._quote_mid if self._quote_mid is not None else last_mid
        if reference_mid and reference_mid > 0 and not should_requote:
            drift = abs(mid - reference_mid) / reference_mid * Decimal("10000")
            if drift > self.drift_threshold_bps:
                log.info("Mid drifted %.1f bps — re-quoting", drift)
                should_requote = True

        if should_requote:
            await self.cancel_all()
            self._quote_mid = mid

            spread = mid * Decimal(str(self.spread_bps)) / Decimal("10000")
            half_spread = spread / 2

            bid_price = round_to_tick(mid - half_spread, self.tick_size)
            # Adding one tick before rounding down keeps the ask from landing
            # inside the requested spread.
            ask_price = round_to_tick(mid + half_spread + self.tick_size, self.tick_size)

            # Ensure bid < ask
            if bid_price >= ask_price:
                ask_price = bid_price + self.tick_size

            log.info("Quoting: BID %.8f x %s | ASK %.8f x %s", bid_price, bid_qty, ask_price, ask_qty)

            await self._place_quote("buy", "BID", bid_qty, bid_price)
            await self._place_quote("sell", "ASK", ask_qty, ask_price)

            if bid_qty < self.min_order_size and ask_qty < self.min_order_size:
                log.warning("Both sides below min order size — skipping")
        else:
            # Refresh tracked orders (some may have filled)
            await self.refresh_open_orders()

        # Log status
        log.info("Mid: %.8f | Open orders: %d | %s", mid, len(self._open_order_ids), self.tracker.summary(mid))
        return mid

    async def run(self, clone_info: dict | None = None, clone_manager: CloneBotManager | None = None, clone_id: str | None = None):
        """Main loop.

        Logs a configuration banner, initializes, then runs quoting cycles
        until ``self.running`` becomes False. When ``clone_manager`` and
        ``clone_id`` are given, the clone's server-side balance is compared
        with the locally tracked balance every fifth cycle. On exit all
        tracked orders are cancelled and a final summary is logged.
        """
        log.info("=" * 70)
        if clone_info:
            log.info("EventTrader Market Maker Bot (Clone Mode)")
            log.info("  Clone:    %s (id: %s)", clone_info.get("name", "?"), clone_info.get("id", "?"))
            log.info("  Source:   %s (%s)", clone_info.get("source_name", "?"), clone_info.get("type", "?"))
            if self.initial_balance > 0:
                log.info("  Balance:  %.2f USDC", self.initial_balance)
            sl = f"{self.clone_stop_loss_pct:.1f}%" if self.clone_stop_loss_pct else "off"
            tp = f"{self.clone_take_profit_pct:.1f}%" if self.clone_take_profit_pct else "off"
            log.info("  SL/TP:    %s / %s", sl, tp)
            skills = clone_info.get("skills", [])
            if skills:
                log.info("  Skills:   %s", ", ".join(str(s) for s in skills))
        else:
            log.info("EventTrader Market Maker Bot")
        log.info("  Pair:     %s/USDC", self.symbol.upper())
        log.info("  Spread:   %d bps (%.2f%%)", self.spread_bps, self.spread_bps / 100)
        log.info("  Size:     %s per side", self.order_size)
        log.info("  Refresh:  %ss", self.refresh_interval)
        log.info("  Mode:     %s", self.mode.upper())
        log.info("  Drift:    %d bps re-quote threshold", self.drift_threshold_bps)
        if self.max_session_loss is not None:
            log.info("  Max loss: %.2f USDC", self.max_session_loss)
        if self.max_position_size is not None:
            log.info("  Max pos:  %s", self.max_position_size)
        log.info("=" * 70)

        await self.initialize()

        last_mid: Decimal | None = None
        cycle = 0
        while self.running:
            cycle += 1
            log.info("--- Cycle %d ---", cycle)
            try:
                last_mid = await self.run_cycle(last_mid)
            except Exception as e:
                log.error("Cycle error: %s", e, exc_info=True)

            # Sync clone balance every 5th cycle
            if clone_manager and clone_id and cycle % 5 == 0:
                try:
                    bal_resp = await clone_manager.get_balance(clone_id)
                    server_bal = Decimal(str(bal_resp.get("available") or bal_resp.get("balance") or bal_resp.get("total", 0)))
                    local_bal = self.initial_balance + self.tracker.total_pnl(last_mid or Decimal("0"))
                    if local_bal > 0:
                        divergence = abs(server_bal - local_bal) / local_bal * 100
                        if divergence > 5:
                            log.warning("Balance divergence: server=%.2f local=%.2f (%.1f%%)", server_bal, local_bal, divergence)
                    log.info("Clone server balance: %.2f USDC", server_bal)
                except Exception as e:
                    log.debug("Could not sync clone balance: %s", e)

            # Sleep in small increments so shutdown is responsive. Always
            # sleep at least once so a tiny refresh interval cannot turn the
            # loop into a busy spin against the API.
            for _ in range(max(1, int(self.refresh_interval * 10))):
                if not self.running:
                    break
                await asyncio.sleep(0.1)

        # Shutdown
        log.info("Shutting down...")
        await self.cancel_all()
        mid = last_mid or Decimal("0")
        log.info("Final: %s", self.tracker.summary(mid))


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the bot's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses the
            process command line, exactly as before; passing a list makes the
            parser usable from tests and from other programs.

    Returns:
        The populated namespace. Credential and base-URL options default to
        the ``ET_API_KEY``, ``ET_EMAIL``, ``ET_USERNAME``, ``ET_PASSWORD``
        and ``ET_BASE_URL`` environment variables.
    """
    p = argparse.ArgumentParser(
        description="EventTrader example market-making bot",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    p.add_argument("--pair", default="vaix", help="Trading pair symbol (default: vaix)")
    p.add_argument("--api-key", default=os.environ.get("ET_API_KEY"), help="API key")
    p.add_argument("--email", default=os.environ.get("ET_EMAIL"), help="Email for login")
    p.add_argument("--username", default=os.environ.get("ET_USERNAME"), help="Username for login (alternative to email)")
    p.add_argument("--password", default=os.environ.get("ET_PASSWORD"), help="Password for login")
    p.add_argument("--base-url", default=os.environ.get("ET_BASE_URL", "https://cymetica.com"), help="API base URL")
    # Literal percent signs in help text must be doubled: argparse %-formats help strings.
    p.add_argument("--spread-bps", type=int, default=50, help="Spread in basis points (default: 50 = 0.5%%)")
    p.add_argument("--size", type=Decimal, default=Decimal("100"), help="Order size per side (default: 100)")
    p.add_argument("--refresh", type=float, default=10.0, help="Refresh interval in seconds (default: 10)")
    p.add_argument("--drift-bps", type=int, default=25, help="Re-quote if mid drifts this many bps (default: 25)")
    p.add_argument("--live", action="store_true",
                   help="Use live mode (standalone default: paper; clone mode is live unless the clone is a paper clone)")
    p.add_argument("--debug", action="store_true", help="Enable debug logging")

    # Clone mode arguments
    clone = p.add_argument_group("clone mode", "Manage a cloned bot with real trading and risk controls")
    clone.add_argument("--clone-id", help="Attach to an existing clone instance by ID")
    clone.add_argument("--clone-from", help="Create a new clone from a species slug (e.g., 'macd', 'momentum')")
    clone.add_argument("--clone-source-type", default="wta_species",
                       choices=["wta_species", "perpetual_agent", "backtest_bot"],
                       help="Clone source type (default: wta_species)")
    clone.add_argument("--clone-name", help="Custom name for a new clone")
    clone.add_argument("--fund-amount", type=float, default=0,
                       help="USDC amount to fund clone from your balance (default: 0 = skip)")
    clone.add_argument("--max-loss", type=Decimal, default=None,
                       help="Emergency stop if session loss exceeds this USDC amount")
    clone.add_argument("--max-position", type=Decimal, default=None,
                       help="Max base token position size")
    clone.add_argument("--withdraw-on-stop", action="store_true",
                       help="Withdraw all clone funds on shutdown")
    clone.add_argument("--equip-skill", action="append", default=[],
                       help="Equip a skill to the clone (repeatable, e.g., --equip-skill market_making --equip-skill rsi_momentum)")

    return p.parse_args(argv)


def _flag_was_passed(flag: str, argv: list[str]) -> bool:
    """Return True if ``flag`` was given on the command line in any accepted form.

    Matches ``--flag value``, ``--flag=value`` and the abbreviated prefixes
    argparse accepts (e.g. ``--spread``). Ambiguous prefixes never reach this
    point because argparse rejects them while parsing.
    """
    for token in argv:
        name = token.split("=", 1)[0]
        if len(name) > 2 and name.startswith("--") and flag.startswith(name):
            return True
    return False


def _available_usdc(balance_resp: dict) -> Decimal:
    """Pick the usable balance from a clone balance response.

    Takes the first truthy value among ``available``, ``balance`` and
    ``total``, and 0 when none is set.
    """
    raw = balance_resp.get("available") or balance_resp.get("balance") or balance_resp.get("total") or 0
    return Decimal(str(raw))


async def main():
    """Parse arguments, authenticate, optionally set up a clone, and run the bot.

    Exits with status 1 when no usable credentials are supplied or when clone
    creation does not return an ID.
    """
    args = parse_args()

    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    clone_mode = bool(args.clone_id or args.clone_from)
    # Clone mode defaults to live; standalone defaults to paper
    if clone_mode:
        mode = "live"  # clone mode = real trading by default
    else:
        mode = "live" if args.live else "paper"

    # Authenticate
    if args.api_key:
        client = EventTrader(api_key=args.api_key, base_url=args.base_url)
    elif args.email and args.password:
        client = await EventTrader.from_credentials(
            args.email, args.password, base_url=args.base_url,
        )
    elif args.username and args.password:
        # Username login — SDK's from_credentials() requires email format,
        # so we do a manual login and inject the token into the client.
        import httpx
        log.info("Logging in as %s...", args.username)
        # TLS certificate verification stays enabled (the httpx default):
        # this request carries the account password, so it must never be
        # sent to an unauthenticated endpoint.
        async with httpx.AsyncClient() as http:
            resp = await http.post(
                f"{args.base_url}/auth/login",
                json={"username": args.username, "password": args.password},
            )
            resp.raise_for_status()
            tokens = resp.json()
        client = EventTrader(base_url=args.base_url)
        # Inject tokens — same pattern the SDK uses internally in JWTAuth.login()
        # NOTE(review): this reaches into private SDK attributes and may break
        # on SDK upgrades.
        new_config = client._config.with_tokens(
            tokens["access_token"], tokens.get("refresh_token"),
        )
        client._http.config = new_config
        # Force re-creation of the httpx client with new auth headers
        if client._http._client and not client._http._client.is_closed:
            await client._http._client.aclose()
        client._http._client = None
        log.info("Logged in as %s", args.username)
    else:
        log.error("Provide --api-key, --email/--password, or --username/--password")
        sys.exit(1)

    # --- Clone setup ---
    clone_info: dict | None = None
    clone_id: str | None = None
    clone_manager: CloneBotManager | None = None
    clone_stop_loss_pct: float | None = None
    clone_take_profit_pct: float | None = None
    initial_balance = Decimal("0")

    if clone_mode:
        clone_manager = CloneBotManager(client._http)

        # Create or attach
        if args.clone_from:
            log.info("Creating clone from %s '%s'...", args.clone_source_type, args.clone_from)
            create_resp = await clone_manager.create_clone(
                source_type=args.clone_source_type,
                source_id=args.clone_from,
                name=args.clone_name,
                is_paper=(mode != "live"),
            )
            clone_id = create_resp.get("id") or ""
            if not clone_id:
                # Without an ID every later clone request would hit a malformed URL.
                log.error("Clone creation returned no id: %s", create_resp)
                sys.exit(1)
            log.info("Clone created: %s (id: %s)", create_resp.get("name", "?"), clone_id)
        else:
            clone_id = args.clone_id
            log.info("Attaching to clone %s...", clone_id)

        # Load profile
        profile = await clone_manager.get_profile(clone_id)
        clone_info = {
            "id": clone_id,
            "name": profile.get("name") or profile.get("custom_name") or profile.get("bot_name", "?"),
            "source_name": profile.get("source_name") or profile.get("species_name", "?"),
            "type": profile.get("type") or profile.get("source_type", "?"),
            "is_paper": profile.get("is_paper", False),
        }

        # Read clone settings for risk controls
        settings = profile.get("settings") or profile.get("trading_settings") or {}
        clone_stop_loss_pct = settings.get("stop_loss_pct")
        clone_take_profit_pct = settings.get("take_profit_pct")

        # Override mode from clone's is_paper if not explicitly set via --live
        if clone_info["is_paper"] and not args.live:
            mode = "paper"

        # Equip skills
        for skill_id in args.equip_skill:
            try:
                resp = await clone_manager.equip_skill(clone_id, skill_id)
                log.info("Equipped skill '%s': %s", skill_id, resp.get("message", "OK"))
            except Exception as e:
                log.warning("Could not equip skill '%s': %s", skill_id, e)

        # Load equipped skills for banner
        equipped_skills: list[str] = []
        spread_set_by_user = _flag_was_passed("--spread-bps", sys.argv[1:])
        try:
            loadout = await clone_manager.get_skills(clone_id)
            for slot_type in ("primary", "secondary", "passive"):
                for s in loadout.get(slot_type, []):
                    skill_name = s.get("name") or s.get("skill_id", "?")
                    equipped_skills.append(skill_name)
                    # Apply market_making custom_params to bot config; values
                    # given explicitly on the command line always win.
                    if s.get("skill_id") == "market_making":
                        params = s.get("custom_params") or {}
                        if "spread_bps" in params and not spread_set_by_user:
                            args.spread_bps = int(params["spread_bps"])
                            log.info("  Skill override: spread_bps=%d", args.spread_bps)
                        if "max_inventory" in params and args.max_position is None:
                            args.max_position = Decimal(str(params["max_inventory"]))
                            log.info("  Skill override: max_position=%s", args.max_position)
        except Exception as e:
            log.debug("Could not load skills: %s", e)
        clone_info["skills"] = equipped_skills

        # Fund the clone
        if args.fund_amount > 0:
            log.info("Funding clone with %.2f USDC...", args.fund_amount)
            fund_resp = await clone_manager.fund(clone_id, args.fund_amount)
            log.info("Funded: %s", fund_resp)

        # Get starting balance
        initial_balance = _available_usdc(await clone_manager.get_balance(clone_id))
        log.info("Clone balance: %.2f USDC", initial_balance)

        # Enable trading on the clone
        await clone_manager.toggle_trading(clone_id, True)
        log.info("Clone trading enabled")

    bot = MarketMaker(
        client=client,
        symbol=args.pair,
        spread_bps=args.spread_bps,
        order_size=args.size,
        refresh_interval=args.refresh,
        mode=mode,
        drift_threshold_bps=args.drift_bps,
        max_session_loss=args.max_loss,
        max_position_size=args.max_position,
        clone_stop_loss_pct=clone_stop_loss_pct,
        clone_take_profit_pct=clone_take_profit_pct,
        initial_balance=initial_balance,
    )

    # Graceful shutdown on SIGINT/SIGTERM
    def _request_stop(*_: Any) -> None:
        bot.running = False

    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, _request_stop)
        except NotImplementedError:
            # Event loops without signal support (e.g. on Windows): fall back
            # to a plain signal handler so the bot can still cancel its
            # orders on Ctrl-C.
            signal.signal(sig, _request_stop)

    try:
        async with client:
            await bot.run(clone_info=clone_info, clone_manager=clone_manager, clone_id=clone_id)

            # Clone shutdown sequence
            if clone_manager and clone_id:
                log.info("Pausing clone trading...")
                try:
                    await clone_manager.toggle_trading(clone_id, False)
                except Exception as e:
                    log.warning("Could not pause clone: %s", e)

                # Report the final balance BEFORE any withdrawal; measured
                # afterwards it would be ~0 and the session P&L would always
                # show a loss of the whole starting balance.
                final_avail: Decimal | None = None
                try:
                    final_avail = _available_usdc(await clone_manager.get_balance(clone_id))
                    pnl = float(final_avail) - float(initial_balance)
                    log.info("Clone final balance: %.2f USDC (session P&L: %+.2f USDC)", float(final_avail), pnl)
                except Exception as e:
                    log.warning("Could not fetch final clone balance: %s", e)

                if args.withdraw_on_stop:
                    try:
                        if final_avail is None:
                            final_avail = _available_usdc(await clone_manager.get_balance(clone_id))
                        available = float(final_avail)
                        if available > 0:
                            log.info("Withdrawing %.2f USDC from clone...", available)
                            await clone_manager.withdraw(clone_id, available)
                            log.info("Withdrawn successfully")
                    except Exception as e:
                        log.warning("Could not withdraw from clone: %s", e)
    except KeyboardInterrupt:
        pass

    log.info("Bot stopped.")


# Run the bot only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())
