"""
Google Flow API v2 Provider — uses the newer flowMedia:batchGenerateImages endpoint.

This provider accesses newer models like GEM_PIX_2 and IMAGEN_4 that are NOT available
on the legacy v1:runImageFx endpoint.

Requirements (from browser DevTools → Network tab):
  - GOOGLE_FLOW_PROJECT_ID: Your Google Labs project ID (optional; without it
    the provider falls back to the legacy v1:runImageFx endpoint)
  - GOOGLE_FLOW_SESSION_TOKEN: __Secure-next-auth.session-token cookie (required)
  - GOOGLE_FLOW_CSRF_TOKEN: __Host-next-auth.csrf-token cookie (required)
  - GOOGLE_FLOW_RECAPTCHA_TOKEN: reCAPTCHA Enterprise token (from network requests;
    optional — the request is attempted without it)

HOW TO GET YOUR TOKENS:
1. Open https://labs.google/fx/tools/image-fx in Chrome
2. Open DevTools (F12) → Network tab
3. Generate an image in the UI
4. Find the POST to aisandbox-pa.googleapis.com/.../flowMedia:batchGenerateImages
5. From the request payload, extract:
   - clientContext.projectId → GOOGLE_FLOW_PROJECT_ID
   - clientContext.recaptchaContext.token → GOOGLE_FLOW_RECAPTCHA_TOKEN
6. Cookies are the same as the legacy provider.
"""

from __future__ import annotations

import asyncio
import base64
import json
import logging
import os
import random
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

import httpx

from providers.ports import ImageProvider

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Session endpoint that exchanges the auth cookies for an access token.
SESSION_URL = "https://labs.google/fx/api/auth/session"

# Flow endpoint; ``{project_id}`` is substituted per request.
FLOW_GENERATE_URL = (
    "https://aisandbox-pa.googleapis.com/v1/projects/{project_id}"
    "/flowMedia:batchGenerateImages"
)

# Legacy endpoint, kept as a fallback.
LEGACY_GENERATE_URL = "https://aisandbox-pa.googleapis.com/v1:runImageFx"

# Configuration name -> aspect-ratio value sent to the API.
ASPECT_RATIOS = dict(
    portrait="IMAGE_ASPECT_RATIO_PORTRAIT",
    landscape="IMAGE_ASPECT_RATIO_LANDSCAPE",
    square="IMAGE_ASPECT_RATIO_SQUARE",
)

# Models served by the Flow endpoint (configuration key -> API model name).
FLOW_MODELS = dict(
    gem_pix_2="GEM_PIX_2",
    imagen_4="IMAGEN_4",
    nano_banana_pro="NANO_BANANA_PRO",
    nano_banana="NANO_BANANA",
)

# Models served by the legacy endpoint (used for fallback).
LEGACY_MODELS = dict(
    imagen_3_5="IMAGEN_3_5",
    imagen_3_5_fast="IMAGEN_3_5_FAST",
    imagen_3="IMAGEN_3",
)

DEFAULT_FLOW_MODEL = FLOW_MODELS["gem_pix_2"]

# Headers attached to every authenticated request.
DEFAULT_HEADERS = {
    "Origin": "https://labs.google",
    "Referer": "https://labs.google/fx/tools/image-fx",
    "Content-Type": "application/json",
    "User-Agent": (
        "Mozilla/5.0 (X11; Linux x86_64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/131.0.0.0 Safari/537.36"
    ),
}


class GoogleFlowV2ImageProvider(ImageProvider):
    """Generates images via Google Flow API (newer endpoint with GEM_PIX_2/IMAGEN_4 support)."""

    name = "google_flow_v2"

    def __init__(self):
        self.session_token = os.getenv("GOOGLE_FLOW_SESSION_TOKEN", "")
        self.csrf_token = os.getenv("GOOGLE_FLOW_CSRF_TOKEN", "")
        self.project_id = os.getenv("GOOGLE_FLOW_PROJECT_ID", "")
        self.recaptcha_token = os.getenv("GOOGLE_FLOW_RECAPTCHA_TOKEN", "")

        self.aspect_ratio = ASPECT_RATIOS.get(
            os.getenv("GOOGLE_FLOW_ASPECT_RATIO", "portrait").lower(),
            "IMAGE_ASPECT_RATIO_PORTRAIT",
        )

        model_key = os.getenv("GOOGLE_FLOW_MODEL", "gem_pix_2").lower()
        all_models = {**FLOW_MODELS, **LEGACY_MODELS}
        self.model = all_models.get(model_key, model_key.upper())

        # Determine if we can use the Flow API or must fall back to legacy
        # Flow API needs at least project_id; recaptcha is optional (may work without)
        self._use_flow_api = bool(self.project_id)

        if not self.session_token:
            raise EnvironmentError(
                "GOOGLE_FLOW_SESSION_TOKEN is required. "
                "Extract from Chrome DevTools → Application → Cookies."
            )
        if not self.csrf_token:
            raise EnvironmentError(
                "GOOGLE_FLOW_CSRF_TOKEN is required. "
                "Extract from Chrome DevTools → Application → Cookies."
            )

        if not self._use_flow_api:
            # Fall back to legacy, but warn about model limitations
            if self.model in FLOW_MODELS.values():
                logger.warning(
                    "Model %s requires Flow API (project_id needed). "
                    "Falling back to IMAGEN_3_5 on legacy endpoint. "
                    "Run Auto Setup in settings to enable Flow API.",
                    self.model,
                )
                self.model = "IMAGEN_3_5"
            logger.info("Using LEGACY endpoint (v1:runImageFx) — model=%s", self.model)
        else:
            logger.info(
                "Using FLOW API (flowMedia:batchGenerateImages) — model=%s project=%s recaptcha=%s",
                self.model,
                self.project_id[:12] + "...",
                "yes" if self.recaptcha_token else "no (will try without)",
            )

        self._access_token: str | None = None
        self._token_expiry: datetime | None = None
        self._last_media_id: str | None = None

        self.client = httpx.Client(timeout=90, follow_redirects=True)

    # ------------------------------------------------------------------
    # Cookie helpers
    # ------------------------------------------------------------------

    def _cookie_string(self) -> str:
        return (
            f"__Secure-next-auth.session-token={self.session_token}; "
            f"__Host-next-auth.csrf-token={self.csrf_token}"
        )

    def _update_session_cookie(self, response: httpx.Response) -> None:
        for cookie_header in response.headers.get_list("set-cookie"):
            if "__Secure-next-auth.session-token=" in cookie_header:
                token = cookie_header.split("__Secure-next-auth.session-token=")[1]
                token = token.split(";")[0]
                if token and token != self.session_token:
                    self.session_token = token
                    logger.debug("Session cookie rotated")
                    self._persist_token(token)
                break

    def _persist_token(self, token: str) -> None:
        env_path = Path(__file__).resolve().parent.parent / ".env"
        if not env_path.exists():
            return
        try:
            content = env_path.read_text(encoding="utf-8")
            if "GOOGLE_FLOW_SESSION_TOKEN=" in content:
                lines = content.splitlines()
                new_lines = []
                for line in lines:
                    if line.startswith("GOOGLE_FLOW_SESSION_TOKEN="):
                        new_lines.append(f"GOOGLE_FLOW_SESSION_TOKEN={token}")
                    else:
                        new_lines.append(line)
                env_path.write_text("\n".join(new_lines) + "\n", encoding="utf-8")
                logger.info("Updated session token in .env")
        except Exception as e:
            logger.warning("Could not persist rotated token: %s", e)

    # ------------------------------------------------------------------
    # Auth
    # ------------------------------------------------------------------

    def _refresh_access_token(self) -> str:
        if (
            self._access_token
            and self._token_expiry
            and datetime.now(timezone.utc) < self._token_expiry
        ):
            return self._access_token

        logger.info("Refreshing access token via session endpoint...")
        resp = self.client.get(
            SESSION_URL,
            headers={**DEFAULT_HEADERS, "Cookie": self._cookie_string()},
        )
        self._update_session_cookie(resp)
        resp.raise_for_status()

        data = resp.json()
        self._access_token = data.get("access_token") or data.get("accessToken")
        expires_str = data.get("expires")

        if not self._access_token:
            raise RuntimeError(
                f"No access_token in session response. Keys: {list(data.keys())}. "
                "Cookies may be expired — re-extract from Chrome."
            )

        if expires_str:
            try:
                self._token_expiry = datetime.fromisoformat(
                    expires_str.replace("Z", "+00:00")
                )
            except ValueError:
                self._token_expiry = None

        logger.info(
            "Access token obtained (expires %s)",
            self._token_expiry.isoformat() if self._token_expiry else "unknown",
        )
        return self._access_token

    def _auth_headers(self) -> dict[str, str]:
        token = self._refresh_access_token()
        return {
            **DEFAULT_HEADERS,
            "Cookie": self._cookie_string(),
            "Authorization": f"Bearer {token}",
        }

    # ------------------------------------------------------------------
    # Flow API generation (newer endpoint)
    # ------------------------------------------------------------------

    def _build_flow_body(self, prompt: str, seed: int | None = None) -> dict:
        """Build payload for flowMedia:batchGenerateImages."""
        if seed is None:
            seed = random.randint(100000, 999999)

        session_id = str(int(time.time() * 1000))

        # Build client context — recaptcha is optional
        client_ctx: dict = {
            "projectId": self.project_id,
            "sessionId": session_id,
            "tool": "PINHOLE",
        }
        if self.recaptcha_token:
            client_ctx["recaptchaContext"] = {
                "applicationType": "RECAPTCHA_APPLICATION_TYPE_WEB",
                "token": self.recaptcha_token,
            }

        return {
            "clientContext": client_ctx,
            "requests": [
                {
                    "clientContext": client_ctx,
                    "imageAspectRatio": self.aspect_ratio,
                    "imageInputs": [],
                    "imageModelName": self.model,
                    "prompt": prompt,
                    "seed": seed,
                }
            ],
        }

    def _generate_flow_sync(
        self, prompt: str, output_path: Path, seed: int | None = None
    ) -> Path:
        """Generate image via the Flow API endpoint."""
        body = self._build_flow_body(prompt, seed=seed)
        url = FLOW_GENERATE_URL.format(project_id=self.project_id)

        logger.info(
            "Generating via Flow API  model=%s  aspect=%s → %s",
            self.model,
            self.aspect_ratio,
            output_path,
        )

        resp = self.client.post(url, json=body, headers=self._auth_headers())
        self._update_session_cookie(resp)
        resp.raise_for_status()

        result = resp.json()
        logger.debug("Flow response keys: %s", list(result.keys()))

        # Flow API returns media[] with fife_url
        media_list = result.get("media", [])
        if not media_list:
            raise RuntimeError(
                f"No media in Flow response. Keys: {list(result.keys())}. "
                f"Response: {json.dumps(result)[:500]}"
            )

        first_media = media_list[0]
        image_data = first_media.get("image", {})
        generated = image_data.get("generatedImage") or image_data.get(
            "generated_image", {}
        )
        fife_url = generated.get("fifeUrl") or generated.get("fife_url", "")
        media_id = generated.get("mediaGenerationId") or generated.get(
            "media_generation_id", ""
        )

        if not fife_url:
            # Try alternate response structures
            logger.warning("No fife_url found, trying alternate paths...")
            # Sometimes the response might have a different structure
            raise RuntimeError(
                f"No fife_url in Flow response. Media keys: {list(first_media.keys())}. "
                f"Image keys: {list(image_data.keys())}. "
                f"Generated keys: {list(generated.keys()) if generated else 'EMPTY'}"
            )

        # Download the image from FIFE URL
        img_url = f"{fife_url}=w1024" if "=" not in fife_url else fife_url
        logger.info("Downloading image from FIFE URL...")
        img_resp = self.client.get(img_url, headers={"Referer": "https://labs.google/"})
        img_resp.raise_for_status()

        with open(output_path, "wb") as f:
            f.write(img_resp.content)

        self._last_media_id = media_id
        logger.info(
            "Image saved: %s (%d bytes) via Flow API",
            output_path,
            len(img_resp.content),
        )
        return output_path

    # ------------------------------------------------------------------
    # Legacy API generation (fallback)
    # ------------------------------------------------------------------

    def _build_legacy_body(
        self, prompt: str, seed: int | None = None, num_images: int = 1
    ) -> dict:
        """Build payload for v1:runImageFx (legacy)."""
        if seed is None:
            seed = random.randint(0, 999999)
        return {
            "userInput": {
                "candidatesCount": num_images,
                "prompts": [prompt],
                "seed": seed,
            },
            "clientContext": {
                "sessionId": f";{int(time.time() * 1000)}",
                "tool": "IMAGE_FX",
            },
            "modelInput": {
                "modelNameType": self.model,
            },
            "aspectRatio": self.aspect_ratio,
        }

    def _generate_legacy_sync(
        self, prompt: str, output_path: Path, seed: int | None = None
    ) -> Path:
        """Generate image via the legacy v1:runImageFx endpoint."""
        body = self._build_legacy_body(prompt, seed=seed)

        logger.info(
            "Generating via Legacy API  model=%s  aspect=%s → %s",
            self.model,
            self.aspect_ratio,
            output_path,
        )

        resp = self.client.post(
            LEGACY_GENERATE_URL, json=body, headers=self._auth_headers()
        )
        self._update_session_cookie(resp)
        resp.raise_for_status()

        result = resp.json()
        panels = result.get("imagePanels") or []
        if not panels:
            raise RuntimeError(
                f"No imagePanels in response. Full: {json.dumps(result)[:500]}"
            )

        images = panels[0].get("generatedImages") or []
        if not images:
            raise RuntimeError(
                f"No generatedImages. Panel keys: {list(panels[0].keys())}"
            )

        first_image = images[0]
        encoded = first_image.get("encodedImage")
        media_id = first_image.get("mediaGenerationId")

        if not encoded:
            raise RuntimeError(f"No encodedImage. Keys: {list(first_image.keys())}")

        img_data = base64.b64decode(encoded)
        with open(output_path, "wb") as f:
            f.write(img_data)

        self._last_media_id = media_id
        logger.info(
            "Image saved: %s (%d bytes) via Legacy API", output_path, len(img_data)
        )
        return output_path

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def _generate_sync(
        self,
        prompt: str,
        output_path: str | Path,
        seed: int | None = None,
        num_images: int = 1,
    ) -> Path:
        """Generate an image. Automatically selects Flow API or Legacy based on credentials."""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if self._use_flow_api:
            try:
                return self._generate_flow_sync(prompt, output_path, seed=seed)
            except Exception as e:
                logger.error("Flow API failed: %s. Trying legacy fallback...", e)
                # Fall back to legacy with safe model
                old_model = self.model
                self.model = "IMAGEN_3_5"
                try:
                    return self._generate_legacy_sync(prompt, output_path, seed=seed)
                finally:
                    self.model = old_model
        else:
            return self._generate_legacy_sync(prompt, output_path, seed=seed)

    async def generate(
        self,
        prompt: str,
        output_path: Path,
        **kwargs: Any,
    ) -> Path:
        """Generate an image asynchronously."""
        return await asyncio.to_thread(
            self._generate_sync, prompt, output_path, **kwargs
        )

    def generate_with_media_id(
        self,
        prompt: str,
        output_path: str | Path,
        seed: int | None = None,
        num_images: int = 1,
    ) -> tuple[Path, str]:
        """Generate an image and return both the path and the mediaGenerationId."""
        path = self._generate_sync(
            prompt, output_path, seed=seed, num_images=num_images
        )
        return path, self._last_media_id

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self.client.close()

    def __enter__(self) -> GoogleFlowV2ImageProvider:
        """Enter a ``with`` block; the provider itself is the context value."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Close the provider when the ``with`` block ends.

        Returns ``None``, so an exception raised inside the block propagates.
        """
        self.close()
