"""
Google ImageFX / Flow Image Provider – generates images via the
aisandbox-pa.googleapis.com API using manually extracted browser cookies.

Auth flow (reverse-engineered):
  1. Exchange browser cookies for a Bearer access_token via
     https://labs.google/fx/api/auth/session
  2. Call https://aisandbox-pa.googleapis.com/v1:runImageFx with Bearer auth

HOW TO GET YOUR COOKIES:
1. Open https://labs.google/fx/tools/image-fx in Chrome
2. Open DevTools (F12) → Application → Cookies → https://labs.google
3. Copy `__Secure-next-auth.session-token` and
   `__Host-next-auth.csrf-token`
4. Paste them into your .env as GOOGLE_FLOW_SESSION_TOKEN and
   GOOGLE_FLOW_CSRF_TOKEN respectively.
"""

from __future__ import annotations

import asyncio
import base64
import json
import logging
import os
import random
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import httpx

from providers.ports import ImageProvider

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Endpoints: the cookie → Bearer token exchange, then the generation RPC.
SESSION_URL = "https://labs.google/fx/api/auth/session"
GENERATE_URL = "https://aisandbox-pa.googleapis.com/v1:runImageFx"

# Short aspect-ratio names → enum strings placed in the request's "aspectRatio".
ASPECT_RATIOS = dict(
    portrait="IMAGE_ASPECT_RATIO_PORTRAIT",
    landscape="IMAGE_ASPECT_RATIO_LANDSCAPE",
    square="IMAGE_ASPECT_RATIO_SQUARE",
)

# Legacy ImageFX model names.
_LEGACY_IMAGEFX_MODELS = {
    "nano_banana_pro": "IMAGEN_3_5",
    "imagen_3_5": "IMAGEN_3_5",
    "imagen_3_5_fast": "IMAGEN_3_5_FAST",
    "imagen_3": "IMAGEN_3",
    "imagen_3_fast": "IMAGEN_3_FAST",
}

# Newer Flow API model names (may or may not work on v1:runImageFx).
_FLOW_API_MODELS = {
    "gem_pix_2": "GEM_PIX_2",
    "imagen_4": "IMAGEN_4",
    "nano_banana": "NANO_BANANA",
    "nano_banana_pro_flow": "NANO_BANANA_PRO",
}

# Combined lookup: configuration key → "modelNameType" value.
# Legacy entries come first, followed by the Flow entries.
IMAGE_MODELS = {**_LEGACY_IMAGEFX_MODELS, **_FLOW_API_MODELS}

# Ordered (key, display_label) pairs for UI display.
IMAGE_MODEL_OPTIONS = list(
    {
        "gem_pix_2": "Gemini Pix 2 (newest)",
        "imagen_4": "Imagen 4",
        "nano_banana_pro_flow": "Nano Banana Pro (Flow)",
        "nano_banana_pro": "Imagen 3.5 (Nano Banana Pro)",
        "imagen_3_5_fast": "Imagen 3.5 Fast",
        "imagen_3": "Imagen 3",
    }.items()
)

# Default model identifier — the legacy "Nano Banana Pro" (Imagen 3.5) entry.
DEFAULT_MODEL = "IMAGEN_3_5"

# Headers attached to every request made to the session and generation URLs.
DEFAULT_HEADERS = dict(
    [
        ("Origin", "https://labs.google"),
        ("Referer", "https://labs.google/fx/tools/image-fx"),
        ("Content-Type", "application/json"),
    ]
)


class GoogleFlowImageProvider(ImageProvider):
    """Generates images via Google ImageFX (labs.google/fx) using browser cookies.

    The provider exchanges the configured session/CSRF cookies for a Bearer
    access token at ``SESSION_URL`` and then posts generation requests to
    ``GENERATE_URL``. It can be used as a context manager so the underlying
    HTTP client is closed on exit.
    """

    name = "google_flow"

    def __init__(self):
        """Read configuration from the environment and create the HTTP client.

        Environment variables:
            GOOGLE_FLOW_SESSION_TOKEN: required session cookie value.
            GOOGLE_FLOW_CSRF_TOKEN: required CSRF cookie value.
            GOOGLE_FLOW_ASPECT_RATIO: optional key of ``ASPECT_RATIOS``
                (default "portrait"; unknown values fall back to portrait).
            GOOGLE_FLOW_MODEL: optional key of ``IMAGE_MODELS`` (default
                "nano_banana_pro"). Unknown values are upper-cased and sent
                as-is; a blank value falls back to ``DEFAULT_MODEL``.

        Raises:
            EnvironmentError: if either required token is missing or blank.
        """
        # Strip so a value pasted with stray whitespace does not end up in
        # the Cookie header, and a whitespace-only value counts as missing.
        self.session_token = os.getenv("GOOGLE_FLOW_SESSION_TOKEN", "").strip()
        self.csrf_token = os.getenv("GOOGLE_FLOW_CSRF_TOKEN", "").strip()
        self.aspect_ratio = ASPECT_RATIOS.get(
            os.getenv("GOOGLE_FLOW_ASPECT_RATIO", "portrait").strip().lower(),
            "IMAGE_ASPECT_RATIO_PORTRAIT",
        )
        model_key = os.getenv("GOOGLE_FLOW_MODEL", "nano_banana_pro").strip().lower()
        # `or DEFAULT_MODEL` guards against an empty GOOGLE_FLOW_MODEL= entry,
        # which would otherwise yield an empty model name in the request.
        self.model = IMAGE_MODELS.get(model_key, model_key.upper()) or DEFAULT_MODEL

        if not self.session_token:
            raise EnvironmentError(
                "GOOGLE_FLOW_SESSION_TOKEN is required. "
                "Extract it from Chrome DevTools → Application → Cookies → "
                "https://labs.google → __Secure-next-auth.session-token"
            )
        if not self.csrf_token:
            raise EnvironmentError(
                "GOOGLE_FLOW_CSRF_TOKEN is required. "
                "Extract it from Chrome DevTools → Application → Cookies → "
                "https://labs.google → __Host-next-auth.csrf-token"
            )

        # Cached Bearer token and its (timezone-aware) expiry.
        self._access_token: str | None = None
        self._token_expiry: datetime | None = None
        # mediaGenerationId of the most recently generated image, if any.
        self._last_media_id: str | None = None

        self.client = httpx.Client(timeout=60, follow_redirects=True)

    # ------------------------------------------------------------------
    # Cookie helpers
    # ------------------------------------------------------------------

    def _cookie_string(self) -> str:
        """Build the Cookie header value from the session and CSRF tokens."""
        return (
            f"__Secure-next-auth.session-token={self.session_token}; "
            f"__Host-next-auth.csrf-token={self.csrf_token}"
        )

    def _update_session_cookie(self, response: httpx.Response) -> None:
        """Adopt a rotated session token found in *response*'s Set-Cookie headers.

        Only the first Set-Cookie header mentioning the session cookie is
        considered. An empty value or one equal to the current token is
        ignored; a new value replaces ``self.session_token`` and is passed to
        ``_persist_token``.
        """
        marker = "__Secure-next-auth.session-token="
        for cookie_header in response.headers.get_list("set-cookie"):
            _, found, remainder = cookie_header.partition(marker)
            if not found:
                continue
            # The cookie value ends at the first attribute separator.
            token = remainder.split(";", 1)[0]
            if token and token != self.session_token:
                self.session_token = token
                logger.debug("Session cookie rotated")
                self._persist_token(token)
            break

    def _persist_token(self, token: str) -> None:
        """Best-effort update of GOOGLE_FLOW_SESSION_TOKEN in the ``.env`` file.

        The file is looked up one directory above this module's directory.
        Nothing is written when the file is absent or when no line starts
        with ``GOOGLE_FLOW_SESSION_TOKEN=``. I/O and decoding errors are
        logged as warnings and never raised.
        """
        env_path = Path(__file__).resolve().parent.parent / ".env"
        if not env_path.exists():
            return

        prefix = "GOOGLE_FLOW_SESSION_TOKEN="
        try:
            lines = env_path.read_text(encoding="utf-8").splitlines()
            # Skip the write (and the "updated" log line) when there is no
            # assignment to replace, e.g. the key only appears in a comment.
            if not any(line.startswith(prefix) for line in lines):
                return
            updated = [
                f"{prefix}{token}" if line.startswith(prefix) else line
                for line in lines
            ]
            env_path.write_text("\n".join(updated) + "\n", encoding="utf-8")
            logger.info("Updated session token in .env")
        except (OSError, UnicodeError) as e:
            logger.warning("Could not persist rotated token to .env: %s", e)

    # ------------------------------------------------------------------
    # Auth: exchange cookies → Bearer access_token
    # ------------------------------------------------------------------

    def _has_cached_access_token(self) -> bool:
        """Return True when a cached access token exists and has not expired."""
        return bool(
            self._access_token
            and self._token_expiry
            and datetime.now(timezone.utc) < self._token_expiry
        )

    @staticmethod
    def _parse_expiry(expires: Any) -> datetime | None:
        """Parse the session response's ``expires`` value.

        Returns a timezone-aware datetime, or None when the value is missing,
        not a string, or not a valid ISO-8601 timestamp.
        """
        if not expires or not isinstance(expires, str):
            return None
        try:
            parsed = datetime.fromisoformat(expires.replace("Z", "+00:00"))
        except ValueError:
            return None
        if parsed.tzinfo is None:
            # A naive value cannot be compared with datetime.now(timezone.utc).
            # NOTE(review): assumes offset-less timestamps are UTC — confirm.
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    def _refresh_access_token(self) -> str:
        """Return a Bearer access token, fetching a new one when needed.

        A cached token is reused while its recorded expiry lies in the
        future. Otherwise ``SESSION_URL`` is called with the cookie header,
        any rotated session cookie is adopted, and the new token is cached.
        When the response carries no usable ``expires`` value the token is
        not treated as cached, so the next call fetches again.

        Raises:
            httpx.HTTPStatusError: if the session endpoint returns an error status.
            RuntimeError: if the response contains no access token.
        """
        if self._has_cached_access_token():
            return self._access_token

        logger.info("Refreshing access token via session endpoint...")
        resp = self.client.get(
            SESSION_URL,
            headers={
                **DEFAULT_HEADERS,
                "Cookie": self._cookie_string(),
            },
        )
        self._update_session_cookie(resp)
        resp.raise_for_status()

        data = resp.json()
        if not isinstance(data, dict):
            # e.g. a JSON `null` body: treat like an empty session object so
            # the error below is raised instead of an AttributeError.
            data = {}

        access_token = data.get("access_token") or data.get("accessToken")
        if not access_token:
            self._access_token = None
            self._token_expiry = None
            raise RuntimeError(
                f"No access_token in session response. Keys: {list(data.keys())}. "
                "Cookies may be expired — re-extract from Chrome."
            )

        self._access_token = access_token
        # Always overwrite, so a stale expiry from an earlier token is dropped.
        self._token_expiry = self._parse_expiry(data.get("expires"))

        logger.info(
            "Access token obtained (expires %s)",
            self._token_expiry.isoformat() if self._token_expiry else "unknown",
        )
        return access_token

    def _auth_headers(self) -> dict[str, str]:
        """Return the default headers plus Cookie and Bearer Authorization."""
        token = self._refresh_access_token()
        return {
            **DEFAULT_HEADERS,
            "Cookie": self._cookie_string(),
            "Authorization": f"Bearer {token}",
        }

    # ------------------------------------------------------------------
    # Image generation
    # ------------------------------------------------------------------

    def _build_generate_body(
        self, prompt: str, seed: int | None = None, num_images: int = 1
    ) -> dict:
        """Build the JSON body for v1:runImageFx.

        When *seed* is None a random integer in [0, 999999] is used.
        """
        if seed is None:
            seed = random.randint(0, 999999)

        return {
            "userInput": {
                "candidatesCount": num_images,
                "prompts": [prompt],
                "seed": seed,
            },
            "clientContext": {
                "sessionId": f";{int(time.time() * 1000)}",
                "tool": "IMAGE_FX",
            },
            "modelInput": {
                "modelNameType": self.model,
            },
            "aspectRatio": self.aspect_ratio,
        }

    def _post_generate(self, body: dict) -> httpx.Response:
        """POST *body* to ``GENERATE_URL`` and return the successful response.

        If the request was made with a cached access token and is answered
        with HTTP 401, the cache is cleared and the request is retried once
        with a freshly fetched token.

        Raises:
            httpx.HTTPStatusError: if the final response has an error status.
        """
        used_cached_token = self._has_cached_access_token()
        resp = self.client.post(GENERATE_URL, json=body, headers=self._auth_headers())
        self._update_session_cookie(resp)

        if resp.status_code == 401 and used_cached_token:
            logger.info("Cached access token rejected (401); refreshing and retrying")
            self._access_token = None
            self._token_expiry = None
            resp = self.client.post(
                GENERATE_URL, json=body, headers=self._auth_headers()
            )
            self._update_session_cookie(resp)

        resp.raise_for_status()
        return resp

    def _generate_and_save(
        self,
        prompt: str,
        output_path: str | Path,
        seed: int | None = None,
        num_images: int = 1,
    ) -> tuple[Path, str | None]:
        """Generate an image, write it to *output_path*, return (path, media id).

        Only the first image of the first panel is saved, regardless of
        *num_images*. The media id is the image's ``mediaGenerationId`` field
        and is None when the response does not include one.

        Raises:
            httpx.HTTPStatusError: on an error status from either endpoint.
            RuntimeError: if the response lacks panels, images, or image data.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        body = self._build_generate_body(prompt, seed=seed, num_images=num_images)

        logger.info(
            "Generating image via Google ImageFX  model=%s  aspect=%s → %s",
            self.model,
            self.aspect_ratio,
            output_path,
        )

        result = self._post_generate(body).json()
        logger.debug("Generate response keys: %s", list(result.keys()))

        # Extract first generated image
        panels = result.get("imagePanels") or []
        if not panels:
            raise RuntimeError(
                f"No imagePanels in response. Full response: {json.dumps(result)[:500]}"
            )

        images = panels[0].get("generatedImages") or []
        if not images:
            raise RuntimeError(
                f"No generatedImages in first panel. Panel keys: {list(panels[0].keys())}"
            )

        first_image = images[0]
        encoded = first_image.get("encodedImage")
        if not encoded:
            raise RuntimeError(
                f"No encodedImage in generated image. Keys: {list(first_image.keys())}"
            )

        img_data = base64.b64decode(encoded)
        output_path.write_bytes(img_data)
        logger.info("Image saved: %s (%d bytes)", output_path, len(img_data))

        return output_path, first_image.get("mediaGenerationId")

    def _generate_sync(
        self,
        prompt: str,
        output_path: str | Path,
        seed: int | None = None,
        num_images: int = 1,
    ) -> Path:
        """Generate an image from *prompt* and save to *output_path*.

        Also records the image's media id in ``self._last_media_id``.

        Returns the Path of the saved image.
        """
        path, media_id = self._generate_and_save(
            prompt, output_path, seed=seed, num_images=num_images
        )
        # Store media_id for potential I2V use
        self._last_media_id = media_id
        return path

    async def generate(
        self,
        prompt: str,
        output_path: Path,
        **kwargs: Any,
    ) -> Path:
        """Generate an image asynchronously.

        Runs ``_generate_sync`` in a worker thread; *kwargs* are forwarded to
        it (``seed``, ``num_images``).
        """
        return await asyncio.to_thread(
            self._generate_sync, prompt, output_path, **kwargs
        )

    def generate_with_media_id(
        self,
        prompt: str,
        output_path: str | Path,
        seed: int | None = None,
        num_images: int = 1,
    ) -> tuple[Path, str | None]:
        """Generate an image and return both the path and the mediaGenerationId.

        The mediaGenerationId can be used as startImage.mediaId for I2V video
        generation. It is None when the response did not include one.
        """
        # Take the id from this call's own result rather than reading it back
        # from self._last_media_id, which a concurrent generate() could have
        # overwritten in the meantime.
        path, media_id = self._generate_and_save(
            prompt, output_path, seed=seed, num_images=num_images
        )
        self._last_media_id = media_id
        return path, media_id

    def close(self) -> None:
        """Close the HTTP client."""
        self.client.close()

    def __enter__(self):
        """Return the provider itself for use in a ``with`` block."""
        return self

    def __exit__(self, *args) -> None:
        """Close the HTTP client when leaving the ``with`` block."""
        self.close()
