"""
Grok Worker Pool -- Parallel video generation with N browser instances.

Manages a pool of GrokVideoProvider instances, each backed by its own
Camoufox (Firefox) browser + Grok account.  Uses asyncio.Semaphore to
cap concurrency and asyncio.gather for batch dispatch.

Machine budget: 32 GB RAM, ~300-500 MB per browser → max 4-5 concurrent.

Usage:
    async with GrokWorkerPool(max_workers=3) as pool:
        await pool.start()
        results = await pool.generate_batch(scenes, on_progress=cb)
"""

from __future__ import annotations

import asyncio
import logging
import time
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from providers.grok_playwright_provider import GrokPlaywrightProvider as GrokVideoProvider

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Data types
# ---------------------------------------------------------------------------


@dataclass
class SceneResult:
    """Outcome of a single scene generation attempt.

    Built by ``GrokWorkerPool._try_generate``: successful results carry
    ``output_path``, ``worker_label`` and ``elapsed_seconds``; failed results
    carry only ``scene_number`` and ``error``.
    """

    scene_number: int
    success: bool
    output_path: Path | None = None  # path returned by the worker on success
    error: str | None = None  # last error message seen when success is False
    worker_label: str = ""  # label of the worker that produced the video
    elapsed_seconds: float = 0.0  # wall-clock duration of the successful attempt


@dataclass
class PoolStatus:
    """Snapshot of worker-pool health.

    Its fields mirror the keys of the dict returned by
    ``GrokWorkerPool.get_pool_status()`` (which returns a plain dict rather
    than this dataclass).
    """

    total_workers: int = 0  # workers created by start()
    active_workers: int = 0  # workers not marked failed
    failed_workers: int = 0  # workers marked failed after a suspected browser crash
    scenes_completed: int = 0  # scenes that finished successfully
    scenes_failed: int = 0  # scenes that failed on every tried worker


# ---------------------------------------------------------------------------
# Single worker
# ---------------------------------------------------------------------------


class GrokWorker:
    """Wraps one GrokVideoProvider plus the Grok account that drives it.

    The provider (and therefore its browser) is created lazily on the first
    ``generate`` call so that pool construction stays cheap and fast.

    Public state read by the pool:
        busy             -- True while ``generate`` is running.
        failed           -- True once an error looked like a browser crash;
                            the provider has then been closed.
        generation_count -- number of successful generations so far.
        last_error       -- message of the most recent failure, else None.
    """

    # Substrings (matched case-insensitively against the exception message)
    # that suggest the browser/session died rather than the generation itself
    # being rejected.  This is a heuristic, not an exhaustive classification.
    _BROWSER_CRASH_TOKENS: tuple[str, ...] = (
        "chrome",
        "firefox",
        "camoufox",
        "session",
        "browser",
        "devtools",
    )

    def __init__(self, account: dict[str, Any]) -> None:
        """Store credentials taken from one account row.

        Args:
            account: mapping with ``id``, ``sso_token``, ``sso_rw_token`` and
                ``user_id`` keys, plus an optional ``label``.

        Raises:
            KeyError: a required key is missing.
        """
        self.account_id: int = account["id"]
        # ``or`` rather than a ``.get`` default so that a NULL/empty label
        # column also falls back; a None label would break joins/formatting
        # of worker labels.
        self.label: str = account.get("label") or f"account-{account['id']}"
        self._sso_token: str = account["sso_token"]
        self._sso_rw_token: str = account["sso_rw_token"]
        self._user_id: str = account["user_id"]

        self._provider: GrokVideoProvider | None = None
        self.busy: bool = False
        self.failed: bool = False
        self.generation_count: int = 0
        self.last_error: str | None = None

    # -- lifecycle ----------------------------------------------------------

    def _ensure_provider(self) -> GrokVideoProvider:
        """Lazy-init: create the provider (and its browser) on first use."""
        if self._provider is None:
            logger.info("Worker [%s] creating GrokVideoProvider...", self.label)
            self._provider = GrokVideoProvider(
                sso_token=self._sso_token,
                sso_rw_token=self._sso_rw_token,
                user_id=self._user_id,
                headless=False,
            )
        return self._provider

    async def close(self) -> None:
        """Shut down the underlying provider, if one was created.

        Errors from the provider's ``close`` are logged, not raised; the
        provider reference is dropped either way so a later ``generate``
        would create a fresh one.
        """
        if self._provider is not None:
            try:
                await self._provider.close()
            except Exception as exc:
                logger.warning("Worker [%s] close error: %s", self.label, exc)
            finally:
                self._provider = None

    @classmethod
    def _looks_like_browser_crash(cls, exc: BaseException) -> bool:
        """Return True when *exc*'s message mentions the browser/session layer."""
        message = str(exc).lower()
        return any(token in message for token in cls._BROWSER_CRASH_TOKENS)

    # -- generation ---------------------------------------------------------

    async def generate(
        self,
        image_path: str | Path,
        video_prompt: str,
        output_path: str | Path,
        on_progress: Callable[[int], None] | None = None,
    ) -> Path:
        """Generate a video via the wrapped provider.

        Args:
            image_path: source image handed to the provider.
            video_prompt: text prompt handed to the provider.
            output_path: destination handed to the provider.
            on_progress: optional callback forwarded to the provider.

        Returns:
            Whatever path the provider's ``generate`` returns.

        Raises:
            Re-raises any provider exception.  If the message looks like a
            browser crash, the worker is first marked ``failed`` and its
            provider closed so the pool skips it from then on.
        """
        self.busy = True
        start = time.monotonic()
        try:
            provider = self._ensure_provider()
            result = await provider.generate(
                image_path=image_path,
                video_prompt=video_prompt,
                output_path=output_path,
                on_progress=on_progress,
            )
            self.generation_count += 1
            self.last_error = None
            return result
        except Exception as exc:
            self.last_error = str(exc)
            if self._looks_like_browser_crash(exc):
                # A dead browser cannot be recovered by this worker.
                logger.error("Worker [%s] browser crashed: %s", self.label, exc)
                self.failed = True
                await self.close()
            raise
        finally:
            self.busy = False
            elapsed = time.monotonic() - start
            logger.debug(
                "Worker [%s] generation took %.1fs (total=%d)",
                self.label,
                elapsed,
                self.generation_count,
            )


# ---------------------------------------------------------------------------
# Worker pool
# ---------------------------------------------------------------------------


class GrokWorkerPool:
    """Manages N parallel GrokWorker instances for concurrent I2V generation.

    Every worker owns one browser, so the pool *leases* a worker to exactly
    one scene at a time.  A scene that finds every healthy untried worker
    busy waits for one to be returned instead of failing.  An
    ``asyncio.Semaphore`` additionally caps in-flight scenes at the number
    of workers created by :meth:`start`.
    """

    def __init__(self, max_workers: int = 3) -> None:
        self.max_workers = max_workers
        self._workers: list[GrokWorker] = []
        # Re-created in start() so it matches the real worker count and is
        # bound to the running event loop.
        self._semaphore: asyncio.Semaphore = asyncio.Semaphore(max_workers)
        self._scenes_completed: int = 0
        self._scenes_failed: int = 0
        self._started: bool = False
        # Workers currently handed to a scene by _lease_worker().
        self._leased: set[GrokWorker] = set()
        # One-shot wake-up for scenes waiting on a worker: fired and dropped by
        # _release_worker(), lazily re-created by _freed_event().
        self._worker_freed: asyncio.Event | None = None

    # -- lifecycle ----------------------------------------------------------

    async def start(self) -> None:
        """Fetch active Grok accounts from DB and create one worker per account.

        If fewer accounts are available than ``max_workers``, the pool
        adjusts downward gracefully.

        Raises:
            RuntimeError: no active account exists in the database.
        """
        if self._started:
            logger.warning("GrokWorkerPool.start() called twice — skipping")
            return

        accounts = await get_available_accounts(self.max_workers)

        if not accounts:
            raise RuntimeError(
                "No active Grok accounts found in database. "
                "Add accounts via the admin panel before starting the pool."
            )

        for acct in accounts:
            self._workers.append(GrokWorker(acct))

        actual = len(self._workers)
        if actual < self.max_workers:
            logger.warning(
                "Only %d active account(s) available (requested %d) — "
                "pool concurrency reduced",
                actual,
                self.max_workers,
            )
        # Always rebuild: a restart after close() must not inherit a semaphore
        # sized for a previous account set.
        self._semaphore = asyncio.Semaphore(actual)

        self._started = True
        logger.info(
            "GrokWorkerPool started: %d worker(s) [%s]",
            actual,
            ", ".join(w.label for w in self._workers),
        )

    async def close(self) -> None:
        """Shut down every worker and forget them.

        Per-worker close errors are collected by ``gather`` and not raised.
        Scenes still waiting for a worker are woken so they can observe that
        none remain and fail instead of hanging.
        """
        close_tasks = [w.close() for w in self._workers]
        if close_tasks:
            await asyncio.gather(*close_tasks, return_exceptions=True)
        self._workers.clear()
        self._leased.clear()
        freed, self._worker_freed = self._worker_freed, None
        if freed is not None:
            freed.set()
        self._started = False
        logger.info("GrokWorkerPool closed")

    async def __aenter__(self) -> GrokWorkerPool:
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        await self.close()

    # -- status -------------------------------------------------------------

    def get_pool_status(self) -> dict[str, int]:
        """Return a snapshot of pool health metrics."""
        active = sum(1 for w in self._workers if not w.failed)
        failed = sum(1 for w in self._workers if w.failed)
        return {
            "total_workers": len(self._workers),
            "active_workers": active,
            "failed_workers": failed,
            "scenes_completed": self._scenes_completed,
            "scenes_failed": self._scenes_failed,
        }

    # -- worker leasing -----------------------------------------------------

    def _freed_event(self) -> asyncio.Event:
        """Return the current wake-up event, creating it inside the running loop."""
        if self._worker_freed is None:
            self._worker_freed = asyncio.Event()
        return self._worker_freed

    async def _lease_worker(self, exclude: set[GrokWorker]) -> GrokWorker | None:
        """Reserve a healthy, idle worker that is not in *exclude*.

        Waits while every such worker is leased to another scene.  Returns
        ``None`` once no healthy worker outside *exclude* remains (or when the
        remaining ones are busy outside the pool's control, since nothing
        would ever wake us).
        """
        while True:
            candidates = [w for w in self._workers if not w.failed and w not in exclude]
            if not candidates:
                return None
            for worker in candidates:
                if worker not in self._leased and not worker.busy:
                    self._leased.add(worker)
                    return worker
            if not any(w in self._leased for w in candidates):
                return None
            # No await between the checks above and grabbing the event, so a
            # release cannot slip in unnoticed (single-threaded event loop).
            await self._freed_event().wait()

    def _release_worker(self, worker: GrokWorker) -> None:
        """Return a leased worker and wake every scene waiting for one.

        Synchronous on purpose so it is safe to call from ``finally`` even
        while the calling task is being cancelled.
        """
        self._leased.discard(worker)
        freed, self._worker_freed = self._worker_freed, None
        if freed is not None:
            freed.set()

    # -- progress helpers ---------------------------------------------------

    def _emit_progress(
        self,
        on_progress: Callable[[int, str, int, int, str | None], None] | None,
        scene_number: int,
        status: str,
        current: int,
        total: int,
        error: str | None,
    ) -> None:
        """Invoke the pool-level progress callback, isolating its failures.

        A raising callback is logged and ignored so it can neither turn a
        finished generation into a failure (and a needless retry on another
        worker) nor abort the whole batch.
        """
        if on_progress is None:
            return
        try:
            on_progress(scene_number, status, current, total, error)
        except Exception:
            logger.exception(
                "on_progress callback raised for scene %d (status=%s)",
                scene_number,
                status,
            )

    def _per_percent_callback(
        self,
        scene_number: int,
        on_progress: Callable[[int, str, int, int, str | None], None] | None,
    ) -> Callable[[int], None] | None:
        """Adapt the pool's 5-arg callback to the worker's percent-only callback."""
        if on_progress is None:
            return None

        def _report(pct: int) -> None:
            self._emit_progress(on_progress, scene_number, "generating", pct, 100, None)

        return _report

    # -- single scene -------------------------------------------------------

    async def generate_scene(
        self,
        scene_number: int,
        image_path: str | Path,
        video_prompt: str,
        output_path: str | Path,
        on_progress: Callable[[int, str, int, int, str | None], None] | None = None,
    ) -> SceneResult:
        """Generate one scene on the next available healthy worker.

        The semaphore caps concurrently running browsers.  If the assigned
        worker fails, the remaining healthy workers are tried (waiting for
        busy ones to free up) before giving up.  Failures are reported in
        the returned :class:`SceneResult`, not raised.
        """
        async with self._semaphore:
            return await self._try_generate(
                scene_number,
                image_path,
                video_prompt,
                output_path,
                on_progress,
            )

    async def _try_generate(
        self,
        scene_number: int,
        image_path: str | Path,
        video_prompt: str,
        output_path: str | Path,
        on_progress: Callable[[int, str, int, int, str | None], None] | None = None,
    ) -> SceneResult:
        """Try each healthy worker at most once until one succeeds or all fail."""
        tried: set[GrokWorker] = set()
        last_error = ""
        per_percent_cb = self._per_percent_callback(scene_number, on_progress)

        while True:
            worker = await self._lease_worker(tried)
            if worker is None:
                break
            tried.add(worker)

            self._emit_progress(on_progress, scene_number, "generating", 0, 100, None)
            start = time.monotonic()
            try:
                result_path = await worker.generate(
                    image_path=image_path,
                    video_prompt=video_prompt,
                    output_path=output_path,
                    on_progress=per_percent_cb,
                )
            except Exception as exc:
                # Generation-level error or browser crash: never retry on the
                # same worker; move on to the next untried healthy one.
                last_error = str(exc) or type(exc).__name__
                logger.error(
                    "Scene %d failed on worker [%s] (%.1fs): %s",
                    scene_number,
                    worker.label,
                    time.monotonic() - start,
                    exc,
                )
                continue
            finally:
                self._release_worker(worker)

            elapsed = time.monotonic() - start
            self._scenes_completed += 1
            self._emit_progress(on_progress, scene_number, "completed", 100, 100, None)
            logger.info(
                "Scene %d completed by [%s] in %.1fs",
                scene_number,
                worker.label,
                elapsed,
            )
            return SceneResult(
                scene_number=scene_number,
                success=True,
                output_path=result_path,
                worker_label=worker.label,
                elapsed_seconds=elapsed,
            )

        # All healthy workers exhausted (or none existed).
        if not tried:
            last_error = "No healthy worker available"
        self._scenes_failed += 1
        self._emit_progress(on_progress, scene_number, "failed", 0, 100, last_error)

        logger.error(
            "Scene %d failed on ALL workers. Last error: %s",
            scene_number,
            last_error,
        )
        return SceneResult(
            scene_number=scene_number,
            success=False,
            error=last_error,
        )

    # -- batch generation ---------------------------------------------------

    async def generate_batch(
        self,
        scenes: list[dict[str, Any]],
        on_progress: Callable[[int, str, int, int, str | None], None] | None = None,
    ) -> list[SceneResult]:
        """Generate all scenes concurrently, limited by the semaphore.

        Each *scene* dict must contain:
            scene_number  (int)
            image_path    (str | Path)
            video_prompt  (str)
            output_path   (str | Path)

        Returns a list of :class:`SceneResult` objects (one per scene, in
        the same order as *scenes*).

        Raises:
            RuntimeError: the pool was not started, or every worker failed.
            KeyError: a scene dict lacks one of the required keys.
        """
        if not self._started:
            raise RuntimeError("Pool not started — call await pool.start() first")

        if not scenes:
            return []

        active_count = sum(1 for w in self._workers if not w.failed)
        if active_count == 0:
            raise RuntimeError("All workers have failed — no healthy browsers left")

        logger.info(
            "Batch: %d scene(s) with %d active worker(s)",
            len(scenes),
            active_count,
        )

        tasks = [
            self.generate_scene(
                scene_number=s["scene_number"],
                image_path=s["image_path"],
                video_prompt=s["video_prompt"],
                output_path=s["output_path"],
                on_progress=on_progress,
            )
            for s in scenes
        ]

        results: list[SceneResult] = await asyncio.gather(*tasks)  # type: ignore[arg-type]
        succeeded = sum(1 for r in results if r.success)
        failed = len(results) - succeeded
        logger.info("Batch done: %d succeeded, %d failed", succeeded, failed)
        return results


# ---------------------------------------------------------------------------
# DB helper — fetch accounts without incrementing usage
# ---------------------------------------------------------------------------


async def get_available_accounts(n: int) -> list[dict[str, Any]]:
    """Return at most *n* active Grok account rows, least-recently-used first.

    Accounts that were never used sort ahead of all others; ties are broken
    by the oldest ``last_used_at`` and then the lowest ``usage_count``.

    This only reads from ``grok_accounts``: unlike ``get_next_grok_account()``
    it leaves usage counters untouched, which is why the pool uses it to
    hand each worker its own account at start-up.
    """
    # Imported lazily, keeping the DB layer out of module import time.
    from web.database_sqlite import get_db

    connection = await get_db()
    query_params = (n,)
    cursor = await connection.execute(
        """
        SELECT *
        FROM grok_accounts
        WHERE is_active = 1
        ORDER BY last_used_at IS NULL DESC, last_used_at ASC, usage_count ASC
        LIMIT ?
        """,
        query_params,
    )
    fetched = await cursor.fetchall()
    accounts: list[dict[str, Any]] = []
    for row in fetched:
        accounts.append(dict(row))
    return accounts
