"""
FlowHttpEngine — HTTP-based engine for Google Labs Flow image generation.

Uses direct HTTP calls instead of browser automation to avoid bot detection.
Based on reverse-engineering of the Flow API (Feb 2026).

Auth flow:
1. Use browser cookies to call /fx/api/auth/session
2. Get Bearer access_token from response
3. Use Bearer token to call image generation API

Endpoints:
- Session: https://labs.google/fx/api/auth/session
- Generate: https://aisandbox-pa.googleapis.com/v1:runImageFx
"""

from __future__ import annotations

import asyncio
import base64
import json
import logging
import os
import random
import time
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional, Callable

import httpx

logger = logging.getLogger(__name__)

# ── Constants ─────────────────────────────────────────────────────────────────
# Called with the browser cookies; its JSON response carries the Bearer token.
SESSION_URL = "https://labs.google/fx/api/auth/session"
# Image generation endpoint, called with the Bearer token from SESSION_URL.
GENERATE_URL = "https://aisandbox-pa.googleapis.com/v1:runImageFx"

# Friendly aspect-ratio names (as used in GOOGLE_FLOW_ASPECT_RATIO) → API enum values.
ASPECT_RATIOS = {
    "portrait": "IMAGE_ASPECT_RATIO_PORTRAIT",
    "landscape": "IMAGE_ASPECT_RATIO_LANDSCAPE",
    "square": "IMAGE_ASPECT_RATIO_SQUARE",
}

# Model mapping: lower-cased GOOGLE_FLOW_MODEL values → API `modelNameType` values.
# NOTE(review): "nano_banana_pro" maps to IMAGEN_3_5 while "nano_banana_pro_flow"
# maps to NANO_BANANA_PRO — confirm this aliasing is intentional.
IMAGE_MODELS = {
    "nano_banana_pro": "IMAGEN_3_5",
    "imagen_3_5": "IMAGEN_3_5",
    "imagen_3_5_fast": "IMAGEN_3_5_FAST",
    "imagen_3": "IMAGEN_3",
    "imagen_3_fast": "IMAGEN_3_FAST",
    "gem_pix_2": "GEM_PIX_2",
    "imagen_4": "IMAGEN_4",
    "nano_banana": "NANO_BANANA",
    "nano_banana_pro_flow": "NANO_BANANA_PRO",
}

# NOTE(review): not referenced anywhere in this module — possibly used by
# importers; confirm before removing.
DEFAULT_MODEL = "IMAGEN_3_5"

# Browser-like headers merged into every session and generation request.
DEFAULT_HEADERS = {
    "Origin": "https://labs.google",
    "Referer": "https://labs.google/fx/tools/flow",
    "Content-Type": "application/json",
    "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
}

# Timing constants
# Seconds slept after each batch finishes (env override; a non-integer value
# raises ValueError at import time).
GENERATION_COOLDOWN = int(os.getenv("FLOW_GENERATION_COOLDOWN", "4"))
# Seconds slept between consecutive items inside one batch (env override).
BATCH_COOLDOWN = int(os.getenv("FLOW_BATCH_COOLDOWN", "2"))
# Extra attempts after the first one (so MAX_RETRIES + 1 requests at most).
MAX_RETRIES = 4
# Seconds to wait before retry 1, 2, ...; the last value is reused if needed.
RETRY_BACKOFF = [5, 10, 20, 30]
# NOTE(review): not referenced in this module; generate_images_batch uses a
# literal default of 10 instead.
DEFAULT_BATCH_SIZE = 10


class FlowHttpEngine:
    """
    HTTP-based engine for Flow image generation.

    Uses direct HTTP calls instead of browser automation: the browser cookies
    are exchanged for a Bearer token at ``SESSION_URL``, and that token is then
    used to call ``GENERATE_URL``. Generation calls on one instance are
    serialized by an internal ``asyncio.Lock``.

    Usage (async):
        engine = FlowHttpEngine(cookies_json)
        await engine.start()
        path = await engine.generate("a cat", Path("out.png"))
        await engine.stop()
    """

    # Treat the token as expired this many seconds before its reported expiry.
    _TOKEN_EXPIRY_MARGIN_S = 30

    def __init__(self, cookies: str | list[dict]):
        """
        Initialize the engine with cookies.

        Args:
            cookies: Either a JSON string of a cookies array, a raw ``Cookie``
                     header string (any string that is not valid JSON), or a
                     list of cookie dicts. Each cookie dict should have 'name'
                     and 'value' keys.

        The model and default aspect ratio are read from the
        ``GOOGLE_FLOW_MODEL`` and ``GOOGLE_FLOW_ASPECT_RATIO`` environment
        variables.
        """
        self._cookies_raw = cookies
        self._cookie_string: str = ""
        self._access_token: str | None = None
        self._token_expiry: datetime | None = None
        self._http_client: httpx.AsyncClient | None = None
        self._ready = False
        # Held by generate()/generate_batch() so only one generation runs at a time.
        self._lock = asyncio.Lock()
        self._last_media_id: str | None = None

        self._parse_cookies()

        model_key = os.getenv("GOOGLE_FLOW_MODEL", "nano_banana_pro").lower()
        # Unknown keys are passed through upper-cased as raw API model names.
        self.model = IMAGE_MODELS.get(model_key, model_key.upper())
        self.aspect_ratio = ASPECT_RATIOS.get(
            os.getenv("GOOGLE_FLOW_ASPECT_RATIO", "portrait").lower(),
            "IMAGE_ASPECT_RATIO_PORTRAIT",
        )

    def _parse_cookies(self) -> None:
        """Convert ``self._cookies_raw`` into an HTTP ``Cookie`` header string."""
        if isinstance(self._cookies_raw, str):
            try:
                cookies_list = json.loads(self._cookies_raw)
            except json.JSONDecodeError:
                # Not JSON: assume it is already a "name=value; ..." header string.
                self._cookie_string = self._cookies_raw
                return
        else:
            cookies_list = self._cookies_raw

        # Cookies with an empty name or value are skipped.
        cookie_parts = []
        for cookie in cookies_list:
            name = cookie.get("name", "")
            value = cookie.get("value", "")
            if name and value:
                cookie_parts.append(f"{name}={value}")

        self._cookie_string = "; ".join(cookie_parts)
        logger.info("FlowHttpEngine: Parsed %d cookies", len(cookie_parts))

    @property
    def is_running(self) -> bool:
        """True once start() succeeded and stop() has not been called since."""
        return self._ready and self._http_client is not None

    @property
    def last_media_id(self) -> str | None:
        """``mediaGenerationId`` of the most recent successful generation, if any."""
        return self._last_media_id

    # ── Lifecycle ─────────────────────────────────────────────────────────────

    async def start(self) -> bool:
        """Create the HTTP client and fetch an initial Bearer token.

        Returns:
            True on success (or if already started); False if startup failed,
            in which case the client is closed again.
        """
        if self._ready:
            return True

        try:
            self._http_client = httpx.AsyncClient(
                timeout=120.0,
                follow_redirects=True,
            )
            await self._refresh_session()
            self._ready = True
            logger.info("FlowHttpEngine: Ready — model=%s", self.model)
            return True
        except Exception as e:
            logger.error("FlowHttpEngine: Start failed — %s", e)
            await self.stop()
            return False

    async def stop(self) -> None:
        """Mark the engine as stopped and close the HTTP client, if any."""
        self._ready = False
        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None
        logger.info("FlowHttpEngine: Stopped")

    @staticmethod
    def _parse_expiry(expires: object) -> datetime | None:
        """Parse the session's ``expires`` value into a timezone-aware datetime.

        A trailing ``Z`` is accepted. Naive timestamps are assumed to be UTC
        (NOTE(review): assumption — confirm the session API's format) so they
        can be compared with ``datetime.now(timezone.utc)``.

        Returns:
            The parsed datetime, or None if *expires* is missing, not a string,
            or not valid ISO-8601.
        """
        if not isinstance(expires, str) or not expires:
            return None
        try:
            parsed = datetime.fromisoformat(expires.replace("Z", "+00:00"))
        except ValueError:
            return None
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    async def _refresh_session(self) -> None:
        """Exchange the session cookies for a fresh Bearer access token.

        Updates ``_access_token`` and ``_token_expiry``.

        Raises:
            RuntimeError: if the client is not initialized, the session endpoint
                returns a non-200 status, or the response has no ``access_token``.
        """
        if not self._http_client:
            raise RuntimeError("HTTP client not initialized")

        headers = {
            **DEFAULT_HEADERS,
            "Cookie": self._cookie_string,
        }

        response = await self._http_client.get(SESSION_URL, headers=headers)

        if response.status_code != 200:
            raise RuntimeError(f"Session refresh failed: HTTP {response.status_code}")

        data = response.json()

        access_token = data.get("access_token")
        expires = data.get("expires")

        if not access_token:
            raise RuntimeError(f"No access_token in session response: {list(data.keys())}")

        self._access_token = access_token
        # Always overwrite: keeping an earlier token's expiry for a new token would
        # be wrong. None makes _is_token_expired() force a refresh next time.
        self._token_expiry = self._parse_expiry(expires)
        if expires and self._token_expiry is None:
            logger.warning("FlowHttpEngine: Could not parse token expiry %r", expires)

        logger.info("FlowHttpEngine: Session refreshed, token expires at %s", self._token_expiry)

    def _is_token_expired(self) -> bool:
        """Return True if there is no usable token or it expires within the margin.

        The token counts as expired ``_TOKEN_EXPIRY_MARGIN_S`` seconds *before*
        its reported expiry, so a request started now does not race the deadline.
        An unknown expiry also counts as expired.
        """
        if not self._access_token or not self._token_expiry:
            return True

        margin = timedelta(seconds=self._TOKEN_EXPIRY_MARGIN_S)
        return self._token_expiry <= datetime.now(timezone.utc) + margin

    def _get_auth_headers(self) -> dict:
        """Return the default headers plus the cookies and the Bearer token.

        Raises:
            RuntimeError: if no access token has been obtained yet.
        """
        if not self._access_token:
            raise RuntimeError("No access token available")

        return {
            **DEFAULT_HEADERS,
            "Cookie": self._cookie_string,
            "Authorization": f"Bearer {self._access_token}",
        }

    def _resolve_aspect_ratio(self, aspect_ratio: str | None) -> str:
        """Map *aspect_ratio* to an API enum value.

        Accepts the friendly names in ``ASPECT_RATIOS`` ("portrait", "landscape",
        "square"; case-insensitive). Raw enum values such as
        "IMAGE_ASPECT_RATIO_PORTRAIT" are passed through unchanged. None or an
        empty string selects the engine's configured default.
        """
        if not aspect_ratio:
            return self.aspect_ratio
        return ASPECT_RATIOS.get(aspect_ratio.lower(), aspect_ratio)

    # ── Image Generation ──────────────────────────────────────────────────────

    async def generate(
        self,
        prompt: str,
        output_path: Path | str,
        aspect_ratio: str | None = None,
        seed: int | None = None,
    ) -> Path:
        """Generate one image using the Flow API and write it to *output_path*.

        Args:
            prompt: Text prompt.
            output_path: Destination file; parent directories are created.
            aspect_ratio: API enum value or friendly name; defaults to the engine's.
            seed: Fixed seed, or None for a random one.

        Returns:
            *output_path* as a Path.

        Raises:
            RuntimeError: if the engine is not started or every attempt failed.
        """
        ar = self._resolve_aspect_ratio(aspect_ratio)
        async with self._lock:
            return await self._generate_impl(prompt, Path(output_path), ar, seed)

    async def generate_batch(
        self,
        items: list[tuple[str, Path | str]],
        aspect_ratio: str | None = None,
        on_progress: Optional[Callable[[int, int], None]] = None,
    ) -> list[Path | None]:
        """Generate multiple images with a short cooldown between them.

        Args:
            items: List of (prompt, output_path) tuples.
            aspect_ratio: Aspect ratio for all images (API enum value or friendly name).
            on_progress: Optional callback(current, total), called after every
                item, successful or not. Exceptions raised by the callback are
                logged and do not affect the results.

        Returns:
            One entry per item, in order: the Path of the generated image, or
            None if that item failed.
        """
        ar = self._resolve_aspect_ratio(aspect_ratio)
        async with self._lock:
            results: list[Path | None] = []
            total = len(items)

            for position, (prompt, output_path) in enumerate(items, start=1):
                try:
                    path = await self._generate_impl(prompt, Path(output_path), ar, None)
                except Exception as e:
                    logger.error("FlowHttpEngine: Batch item %d/%d failed: %s", position, total, e)
                    path = None
                # Exactly one entry per item keeps results aligned with items.
                results.append(path)

                if on_progress:
                    try:
                        on_progress(position, total)
                    except Exception:
                        logger.exception("FlowHttpEngine: on_progress callback failed")

                # Short cooldown within the batch, also after a failed item.
                if position < total:
                    await asyncio.sleep(BATCH_COOLDOWN)

            # Full cooldown at the end of a non-empty batch.
            if items:
                await asyncio.sleep(GENERATION_COOLDOWN)
            return results

    async def _save_image(self, data: dict, output_path: Path) -> str | None:
        """Write the first image found in a generation response to *output_path*.

        Two response shapes are handled:
          * ImageFX: ``{"imagePanels": [{"generatedImages": [{"encodedImage":
            <base64>, "mediaGenerationId": ...}]}]}`` — decoded and written.
          * Flow: ``{"media": [{"image": {"generatedImage": {"fifeUrl": <url>,
            "mediaGenerationId": ...}}}]}`` — downloaded from ``fifeUrl``.

        Returns:
            The image's ``mediaGenerationId`` (may be None).

        Raises:
            RuntimeError: if the response has no usable image or the download fails.
        """
        panels = data.get("imagePanels", [])
        if panels:
            images = panels[0].get("generatedImages", [])
            if images and images[0].get("encodedImage"):
                img = images[0]
                output_path.write_bytes(base64.b64decode(img["encodedImage"]))
                return img.get("mediaGenerationId")

        media_list = data.get("media", [])
        if media_list:
            generated = media_list[0].get("image", {}).get("generatedImage", {})
            image_url = generated.get("fifeUrl")
            if image_url:
                dl_resp = await self._http_client.get(image_url)
                if dl_resp.status_code != 200:
                    raise RuntimeError(f"Image download failed: HTTP {dl_resp.status_code}")
                output_path.write_bytes(dl_resp.content)
                return generated.get("mediaGenerationId")

        raise RuntimeError(f"No image data in response: {list(data.keys())}")

    async def _generate_impl(
        self,
        prompt: str,
        output_path: Path,
        aspect_ratio: str,
        seed: int | None,
    ) -> Path:
        """Generate one image with retries; the caller must hold ``self._lock``.

        Makes up to ``MAX_RETRIES + 1`` attempts, waiting ``RETRY_BACKOFF``
        seconds and refreshing the token before each retry.

        Raises:
            RuntimeError: if the engine is not started or every attempt failed.
        """
        if not self._ready:
            raise RuntimeError("FlowHttpEngine not started.")

        if self._is_token_expired():
            await self._refresh_session()

        output_path.parent.mkdir(parents=True, exist_ok=True)
        logger.info("FlowHttpEngine: Generating → %s", output_path.name)

        body = self._build_request_body(prompt, aspect_ratio, seed)

        last_error: str | None = None
        for attempt in range(MAX_RETRIES + 1):
            if attempt > 0:
                backoff = RETRY_BACKOFF[min(attempt - 1, len(RETRY_BACKOFF) - 1)]
                logger.info("FlowHttpEngine: Retry %d, waiting %ds...", attempt, backoff)
                await asyncio.sleep(backoff)

            try:
                if attempt > 0:
                    # Refresh the token before every retry. Doing it inside the try
                    # means a failed refresh counts as one failed attempt instead
                    # of aborting the remaining retries.
                    await self._refresh_session()

                response = await self._http_client.post(
                    GENERATE_URL,
                    json=body,
                    headers=self._get_auth_headers(),
                )

                if response.status_code != 200:
                    last_error = f"HTTP {response.status_code}: {response.text[:200]}"
                    logger.error("FlowHttpEngine: API error — %s", last_error)
                    continue

                media_id = await self._save_image(response.json(), output_path)
                if media_id:
                    self._last_media_id = media_id

                size = output_path.stat().st_size
                logger.info("FlowHttpEngine: Image saved → %s (%d bytes)", output_path.name, size)
                return output_path

            except Exception as e:
                last_error = str(e)
                logger.error("FlowHttpEngine: Request failed — %s", e)
                continue

        raise RuntimeError(f"Generation failed after {MAX_RETRIES} retries: {last_error}")

    def _build_request_body(self, prompt: str, aspect_ratio: str, seed: int | None) -> dict:
        """Build the JSON body for one generation request (one candidate image).

        A random 31-bit seed is used when *seed* is None. The ``sessionId`` is a
        ';'-prefixed millisecond timestamp (NOTE(review): presumably mirrors what
        the web client sends — confirm).
        """
        return {
            "userInput": {
                "candidatesCount": 1,
                "prompts": [prompt],
                "seed": seed if seed is not None else random.randint(0, 2**31 - 1),
            },
            "aspectRatio": aspect_ratio,
            "modelInput": {
                "modelNameType": self.model,
            },
            "clientContext": {
                "sessionId": f";{int(time.time() * 1000)}",
                "tool": "IMAGE_FX",
            },
        }


# ── Module-level singleton and helpers ────────────────────────────────────────

# Process-wide engine shared by get_engine()/stop_engine(); None until first use
# and again after stop_engine().
_engine: FlowHttpEngine | None = None
# Serializes creation and teardown of _engine.
# NOTE(review): created at import time; on Python < 3.10 an asyncio.Lock binds to
# the event loop current at construction, which can fail if the module is
# imported before asyncio.run() creates a new loop — confirm the minimum
# supported Python version.
_engine_lock = asyncio.Lock()


def _cookies_from_env() -> str:
    """Return the Flow cookie payload configured through environment variables.

    ``GOOGLE_FLOW_COOKIES`` wins when set. Otherwise a JSON cookie list is
    synthesized from ``GOOGLE_FLOW_SESSION_TOKEN`` and ``GOOGLE_FLOW_CSRF_TOKEN``.

    Raises:
        RuntimeError: if neither source is configured.
    """
    configured = os.getenv("GOOGLE_FLOW_COOKIES", "")
    if configured:
        return configured

    session_token = os.getenv("GOOGLE_FLOW_SESSION_TOKEN", "")
    csrf_token = os.getenv("GOOGLE_FLOW_CSRF_TOKEN", "")
    if not (session_token and csrf_token):
        raise RuntimeError(
            "GOOGLE_FLOW_COOKIES or (GOOGLE_FLOW_SESSION_TOKEN + GOOGLE_FLOW_CSRF_TOKEN) required"
        )

    return json.dumps([
        {"name": "__Secure-next-auth.session-token", "value": session_token},
        {"name": "__Host-next-auth.csrf-token", "value": csrf_token},
        {"name": "__Secure-next-auth.callback-url", "value": "https%3A%2F%2Flabs.google%2Ffx%2Fpt%2Ftools%2Fflow"},
    ])


async def get_engine() -> FlowHttpEngine:
    """Return the shared engine, building and starting a new one when needed.

    A new engine is created whenever there is none yet or the current one is
    not running. Its cookies come from the environment.

    Raises:
        RuntimeError: if no cookies are configured or the engine fails to start.
    """
    global _engine
    async with _engine_lock:
        if _engine is not None and _engine.is_running:
            return _engine

        _engine = FlowHttpEngine(_cookies_from_env())
        if not await _engine.start():
            raise RuntimeError("Failed to start FlowHttpEngine")
        return _engine


async def stop_engine() -> None:
    """Stop and discard the shared engine; a no-op when none exists."""
    global _engine
    async with _engine_lock:
        if _engine is None:
            return
        await _engine.stop()
        _engine = None


async def generate_image(
    prompt: str,
    output_path: Path | str,
    aspect_ratio: str = "IMAGE_ASPECT_RATIO_PORTRAIT",
) -> Path:
    """Generate a single image with the shared Flow HTTP engine.

    Obtains (and starts, if needed) the module-level engine via get_engine(),
    then delegates to its generate() method.

    Returns:
        The path the image was written to.
    """
    shared = await get_engine()
    saved_path = await shared.generate(prompt, output_path, aspect_ratio=aspect_ratio)
    return saved_path


async def generate_images_batch(
    items: list[tuple[str, Path | str]],
    aspect_ratio: str = "IMAGE_ASPECT_RATIO_PORTRAIT",
    batch_size: int = 10,
    on_progress: Optional[Callable[[int, int], None]] = None,
    on_batch_complete: Optional[Callable[[int, int], None]] = None,
) -> list[Path | None]:
    """High-level: generate multiple images in batches with the shared engine.

    *items* is split into consecutive chunks of at most *batch_size*, and each
    chunk is passed to the engine's generate_batch().

    Args:
        items: (prompt, output_path) pairs.
        aspect_ratio: Aspect ratio used for every image.
        batch_size: Maximum number of items per chunk; must be >= 1.
        on_progress: Optional callback(position, total_items), where *position*
            is the 1-based index of the item just processed across all chunks.
        on_batch_complete: Optional callback(batch_number, total_batches),
            1-based, called after each chunk.

    Returns:
        One entry per item, in input order: the output Path, or None on failure.

    Raises:
        ValueError: if *batch_size* is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not items:
        # Nothing to generate: don't start the engine (or require credentials).
        return []

    engine = await get_engine()

    all_results: list[Path | None] = []
    total_items = len(items)
    total_batches = -(-total_items // batch_size)  # ceiling division

    for batch_num, start_idx in enumerate(range(0, total_items, batch_size), start=1):
        batch_items = items[start_idx:start_idx + batch_size]

        logger.info("FlowHttpEngine: Processing batch %d/%d (%d items)",
                    batch_num, total_batches, len(batch_items))

        # The chunk offset is bound as a default argument so the callback always
        # reports global positions for *this* chunk (no late-binding surprise).
        def batch_progress(current: int, _chunk_total: int, _offset: int = start_idx) -> None:
            if on_progress:
                on_progress(_offset + current, total_items)

        batch_results = await engine.generate_batch(
            batch_items,
            aspect_ratio=aspect_ratio,
            on_progress=batch_progress,
        )
        all_results.extend(batch_results)

        if on_batch_complete:
            on_batch_complete(batch_num, total_batches)

    return all_results
