"""
Pipeline orchestrator – runs all 5 agents in sequence and assembles the
ProductionPlan.

Flow:
  1. ScriptAnalyzer   → narrative beats
  2. FrameExpander    → 25 frames/min with camera variety
  3. PromptEngineer   → image + video prompt per frame (one LLM call each)
  4. TTSScriptwriter  → narration text per frame
  5. TimelinePlanner  → exact timestamps and cut points
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Callable, Optional
from concurrent.futures import ThreadPoolExecutor

from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
)

from models import (
    AllScenePrompts,
    AllSceneTTS,
    ExpandedFrameBreakdown,
    PipelineConfig,
    PipelineEvent,
    PipelineMode,
    ProductionPlan,
    ProductionScene,
    SceneBreakdown,
    SceneTTS,
    Timeline,
)
from agents.script_analyzer import ScriptAnalyzerAgent
from agents.frame_expander import FrameExpanderAgent
from agents.prompt_engineer import PromptEngineerAgent
from agents.tts_scriptwriter import TTSScriptwriterAgent
from agents.timeline_planner import TimelinePlannerAgent

logger = logging.getLogger(__name__)
# Rich console for progress/status output; force_terminal keeps ANSI styling
# when stdout is not a TTY, soft_wrap avoids hard truncation of long lines.
console = Console(force_terminal=True, soft_wrap=True)


# ---------------------------------------------------------------------------
# Merge helpers
# ---------------------------------------------------------------------------


def _merge_into_production_scenes(
    frames: ExpandedFrameBreakdown,
    prompts: AllScenePrompts,
    tts: AllSceneTTS,
    timeline: Timeline,
) -> list[ProductionScene]:
    """Combine outputs from all agents into unified ProductionScene list.

    Each expanded frame is joined (by frame/scene number) with its prompt,
    TTS script and timeline entry.  Missing counterparts fall back to
    neutral defaults or to the frame's own fields, so a partial agent
    failure still yields a complete plan.
    """
    by_prompt = {entry.scene_number: entry for entry in prompts.scenes}
    by_tts = {entry.scene_number: entry for entry in tts.scenes}
    by_time = {entry.scene_number: entry for entry in timeline.scenes}

    merged: list[ProductionScene] = []
    for frame in frames.frames:
        num = frame.frame_number
        prompt = by_prompt.get(num)
        narration = by_tts.get(num)
        timing = by_time.get(num)

        # Timing defaults when the timeline has no entry for this frame.
        if timing is not None:
            start = timing.time_start
            end = timing.time_end
            duration = timing.duration_seconds
            cut = timing.cut_point
            transition = timing.transition
        else:
            start = end = cut = "00:00.00"
            duration = 0.0
            transition = "crossfade 0.3s"

        # Prefer the prompt engineer's environment; fall back to the
        # frame's own detail when the prompt is missing or empty.
        environment = (
            prompt.environment
            if prompt and prompt.environment
            else frame.environment_detail
        )

        merged.append(
            ProductionScene(
                scene_number=num,
                time_start=start,
                time_end=end,
                duration_seconds=duration,
                cut_point=cut,
                transition=transition,
                environment=environment,
                pose_action=frame.pose_action,
                mood=frame.mood,
                camera_shot=frame.camera_shot,
                camera_movement=frame.camera_movement,
                composition=frame.composition,
                visual_focus=frame.visual_focus,
                image_prompt=prompt.image_prompt if prompt else "",
                video_prompt=prompt.video_prompt if prompt else "",
                tts_text=narration.tts_text if narration else frame.narration_text,
                voice_direction=narration.voice_direction if narration else "",
                narration_notes=frame.narration_notes,
                sync_word=prompt.sync_word if prompt else frame.sync_word,
                anatomical_highlight=prompt.anatomical_highlight
                if prompt
                else frame.anatomical_highlight,
                clothing_description=prompt.clothing_description if prompt else "",
                emotional_tone=prompt.emotional_tone if prompt else frame.mood,
                # Coverage metadata from Frame (older frames may lack these).
                coverage_group_id=getattr(frame, "coverage_group_id", None),
                coverage_angle_index=getattr(frame, "coverage_angle_index", 0),
                coverage_angle_type=getattr(frame, "coverage_angle_type", ""),
            )
        )
    return merged


def _build_assembly_instructions(plan: ProductionPlan) -> list[str]:
    """Generate step-by-step assembly instructions."""
    instructions = [
        "=== ASSEMBLY INSTRUCTIONS ===",
        "",
        f"Total frames: {plan.total_scenes}",
        f"Total duration: {plan.total_duration_formatted}",
        f"Frames per minute: ~{round(plan.total_scenes / max(plan.total_duration_seconds / 60, 0.1))}",
        "",
        "STEP 1 — GENERATE IMAGES",
        "For each frame below, copy the 'Image Prompt' and paste it into your "
        "image generator (DALL-E 3, Midjourney, Flux, etc.).",
        "Save each generated image as frame_XX.png.",
        "",
        "STEP 2 — GENERATE VIDEO CLIPS",
        "For each frame, upload the generated image to your video generator "
        "(Runway ML, Kling AI, Luma, etc.) and paste the 'Video Prompt'.",
        "Save each video clip as frame_XX.mp4.",
        "",
        "STEP 3 — GENERATE VOICEOVER",
        "For each frame, copy the 'TTS Text' and paste it into ElevenLabs.",
        f"Use Voice ID: {plan.voice_settings.voice_id}",
        f"Use Model: {plan.voice_settings.model_id}",
        f"Stability: {plan.voice_settings.stability} | "
        f"Similarity: {plan.voice_settings.similarity_boost} | "
        f"Speed: {plan.voice_settings.speed}",
        "Save each audio as frame_XX.mp3.",
        "",
        "STEP 4 — ASSEMBLE IN VIDEO EDITOR",
        f"Create a new project: {plan.export_settings.resolution} @ "
        f"{plan.export_settings.fps}fps ({plan.export_settings.aspect_ratio})",
    ]

    for s in plan.scenes:
        instructions.append(
            f"  [{s.camera_shot}] Place frame_{s.scene_number:02d}.mp4 "
            f"at {s.time_start}"
        )
        instructions.append(
            f"    Overlay frame_{s.scene_number:02d}.mp3 audio at {s.time_start}"
        )
        instructions.append(f"    Cut at {s.cut_point} | {s.transition}")

    instructions += [
        "",
        "STEP 5 — EXPORT",
        f"Export as {plan.export_settings.video_format.upper()} "
        f"({plan.export_settings.resolution}, {plan.export_settings.fps}fps)",
        f"Total duration: {plan.total_duration_formatted}",
        "",
        "STEP 6 — PUBLISH",
        f"Upload to {plan.target_platform} and add captions/hashtags.",
    ]
    return instructions


# ---------------------------------------------------------------------------
# Pipeline
# ---------------------------------------------------------------------------


class Pipeline:
    """Orchestrates the full video production pipeline (5 agents).

    Agents 1 (Script Analyzer), 2 (Frame Expander) and 5 (Timeline
    Planner) run sequentially; agents 3 (Prompt Engineer) and
    4 (TTS Scriptwriter) run concurrently in a thread pool because both
    depend only on the Frame Expander output.  Progress is reported via
    rich console output and, when registered, an ``on_event`` callback.
    """

    def __init__(
        self,
        config: PipelineConfig,
        on_event: Optional[Callable[[PipelineEvent], None]] = None,
    ):
        """Store the pipeline configuration and optional event callback.

        ``on_event`` receives a PipelineEvent for lifecycle changes
        (agent_start / agent_complete / image_generated / error / ...).
        """
        self.config = config
        self.model = config.llm_model
        self._on_event = on_event

    def _emit(self, **kwargs) -> None:
        """Emit a pipeline event if a callback is registered."""
        if self._on_event is not None:
            self._on_event(PipelineEvent(**kwargs))

    def run(self, story: str) -> ProductionPlan:
        """Execute the pipeline and return a ProductionPlan.

        Any exception is surfaced to the event callback as an ``error``
        event before being re-raised to the caller.
        """
        cfg = self.config

        try:
            return self._run_pipeline(story, cfg)
        except Exception as exc:
            self._emit(event_type="error", message=str(exc))
            raise

    def _run_pipeline(self, story: str, cfg: PipelineConfig) -> ProductionPlan:
        """Internal pipeline execution with event emission.

        Runs agents 1-5, merges their outputs into a ProductionPlan, and
        optionally generates assets depending on ``cfg.mode``.
        """

        # ── Agent 1: Script Analyzer ──────────────────────────────────
        self._emit(
            event_type="agent_start",
            agent_number=1,
            agent_name="Script Analyzer",
            message="Analyzing script into narrative beats...",
        )
        with Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.description}"),
            console=console,
        ) as progress:
            task = progress.add_task(
                "Agent 1/5: Analyzing script into narrative beats...", total=None
            )
            analyzer = ScriptAnalyzerAgent(model=self.model)
            breakdown: SceneBreakdown = analyzer.analyze(story, cfg.target_platform)
            progress.update(task, description="[green]Agent 1/5: Script analyzed")
        console.print(
            f"  [dim]-> {breakdown.total_scenes} narrative beats identified[/dim]"
        )
        self._emit(
            event_type="agent_complete",
            agent_number=1,
            agent_name="Script Analyzer",
            message=f"{breakdown.total_scenes} narrative beats identified",
        )

        # ── Agent 2: Frame Expander ───────────────────────────────────
        self._emit(
            event_type="agent_start",
            agent_number=2,
            agent_name="Frame Expander",
            message=f"Expanding beats into ~{cfg.frames_per_minute} cinematic frames...",
        )
        with Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.description}"),
            console=console,
        ) as progress:
            task = progress.add_task(
                f"Agent 2/5: Expanding beats into ~{cfg.frames_per_minute} "
                f"cinematic frames...",
                total=None,
            )
            expander = FrameExpanderAgent(model=self.model)
            frames: ExpandedFrameBreakdown = expander.expand(breakdown, cfg)
            progress.update(task, description="[green]Agent 2/5: Frames expanded")
        console.print(
            f"  [dim]-> {frames.total_frames} frames with varied camera angles[/dim]"
        )
        self._emit(
            event_type="agent_complete",
            agent_number=2,
            agent_name="Frame Expander",
            message=f"{frames.total_frames} frames with varied camera angles",
        )

        # ── Agents 3 & 4: Run in PARALLEL ─────────────────────────────
        # Both agents consume only `frames`, so they can be fanned out on
        # a two-worker thread pool and joined before Agent 5 runs.
        self._emit(
            event_type="agent_start",
            agent_number=3,
            agent_name="Prompt Engineer",
            message=f"Engineering prompts ({frames.total_frames} frames, streaming)...",
            progress_current=0,
            progress_total=frames.total_frames,
        )
        self._emit(
            event_type="agent_start",
            agent_number=4,
            agent_name="TTS Scriptwriter",
            message="Writing narration scripts...",
        )
        console.print(
            f"\n[bold blue]Agents 3+4/5: Running Prompt Engineer + TTS Scriptwriter "
            f"in parallel...[/bold blue]"
        )

        def _run_prompt_engineer() -> AllScenePrompts:
            # Thread worker: generate image + video prompts for every frame.
            prompter = PromptEngineerAgent(model=self.model)
            return prompter.generate(
                frames,
                cfg.character_template,
                world_style=cfg.world_style,
                clothing_style=cfg.clothing_style,
                production_rules=cfg.production_rules,
                product_placement=cfg.product_placement,
            )

        def _run_tts_scriptwriter() -> AllSceneTTS:
            # Thread worker: write narration only for "master" frames
            # (coverage_angle_index == 0) so alternate coverage angles of
            # the same beat do not get duplicate narration.
            tts_writer = TTSScriptwriterAgent(model=self.model)
            from models import SceneAnalysis

            # Filter to master frames for TTS (avoid duplicating TTS for coverage angles)
            coverage_enabled = any(
                getattr(f, 'coverage_group_id', None) for f in frames.frames
            )
            tts_source_frames = [
                f for f in frames.frames
                if not coverage_enabled or getattr(f, 'coverage_angle_index', 0) == 0
            ]

            # Adapt frames to the SceneBreakdown shape the TTS agent expects.
            frame_as_scenes = SceneBreakdown(
                title=frames.title,
                summary=frames.summary,
                total_scenes=len(tts_source_frames),
                scenes=[
                    SceneAnalysis(
                        scene_number=f.frame_number,
                        environment=f.environment_detail,
                        pose_action=f.pose_action,
                        mood=f.mood,
                        narration_text=f.narration_text,
                        narration_notes=f.narration_notes,
                    )
                    for f in tts_source_frames
                ],
            )
            return tts_writer.generate(frame_as_scenes)

        with ThreadPoolExecutor(max_workers=2) as executor:
            prompt_future = executor.submit(_run_prompt_engineer)
            tts_future = executor.submit(_run_tts_scriptwriter)

            # .result() re-raises any worker exception here.
            prompts: AllScenePrompts = prompt_future.result()
            tts: AllSceneTTS = tts_future.result()

        # Replicate TTS for secondary coverage angles
        # (recomputed here because the worker-local flag is out of scope).
        coverage_enabled = any(
            getattr(f, 'coverage_group_id', None) for f in frames.frames
        )
        if coverage_enabled:
            # Build lookup: coverage_group_id → master's SceneTTS
            master_frame_map = {
                f.frame_number: f for f in frames.frames
                if getattr(f, 'coverage_angle_index', 0) == 0
                and getattr(f, 'coverage_group_id', None)
            }
            tts_by_group: dict[str, SceneTTS] = {}
            for scene_tts in tts.scenes:
                master_frame = master_frame_map.get(scene_tts.scene_number)
                if master_frame and master_frame.coverage_group_id:
                    tts_by_group[master_frame.coverage_group_id] = scene_tts

            # Create TTS entries for secondary angles
            secondary_tts = []
            for f in frames.frames:
                if getattr(f, 'coverage_angle_index', 0) > 0:
                    group_id = getattr(f, 'coverage_group_id', None)
                    if group_id and group_id in tts_by_group:
                        src = tts_by_group[group_id]
                        secondary_tts.append(SceneTTS(
                            scene_number=f.frame_number,
                            tts_text=src.tts_text,
                            voice_direction=src.voice_direction,
                        ))
            tts.scenes.extend(secondary_tts)
            tts.scenes.sort(key=lambda s: s.scene_number)

        console.print(
            f"  [dim]-> {len(prompts.scenes)} image + video prompts generated[/dim]"
        )
        console.print(f"  [dim]-> {len(tts.scenes)} TTS scripts[/dim]")
        self._emit(
            event_type="agent_complete",
            agent_number=3,
            agent_name="Prompt Engineer",
            message=f"{len(prompts.scenes)} image + video prompts generated",
        )
        self._emit(
            event_type="agent_complete",
            agent_number=4,
            agent_name="TTS Scriptwriter",
            message=f"{len(tts.scenes)} TTS scripts generated",
        )
        # ── Agent 5: Timeline Planner ─────────────────────────────────
        self._emit(
            event_type="agent_start",
            agent_number=5,
            agent_name="Timeline Planner",
            message="Planning timeline...",
        )
        with Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.description}"),
            console=console,
        ) as progress:
            task = progress.add_task("Agent 5/5: Planning timeline...", total=None)
            planner = TimelinePlannerAgent(config=cfg)
            timeline: Timeline = planner.plan(tts)
            progress.update(task, description="[green]Agent 5/5: Timeline calculated")
        console.print(
            f"  [dim]-> Total duration: {timeline.total_duration_formatted}[/dim]"
        )
        self._emit(
            event_type="agent_complete",
            agent_number=5,
            agent_name="Timeline Planner",
            message=f"Total duration: {timeline.total_duration_formatted}",
        )

        # ── Merge everything ──────────────────────────────────────────
        production_scenes = _merge_into_production_scenes(
            frames, prompts, tts, timeline
        )

        plan = ProductionPlan(
            title=frames.title,
            summary=frames.summary,
            total_duration_formatted=timeline.total_duration_formatted,
            total_duration_seconds=timeline.total_duration_seconds,
            total_scenes=frames.total_frames,
            target_platform=cfg.target_platform.value,
            character_template_name=cfg.character_template.name,
            voice_settings=cfg.voice_settings,
            export_settings=cfg.export_settings,
            scenes=production_scenes,
        )

        plan.assembly_instructions = _build_assembly_instructions(plan)

        # ── Optional: Generate assets based on mode ────────────────────
        if cfg.mode == PipelineMode.PLAN_AND_IMAGES:
            self._generate_images(plan)
        elif cfg.mode == PipelineMode.GENERATE:
            self._generate_all_assets(plan)

        self._emit(
            event_type="pipeline_complete",
            message=f"Pipeline complete: {plan.title} ({plan.total_scenes} scenes, {plan.total_duration_formatted})",
        )

        return plan

    def _generate_images(self, plan: ProductionPlan) -> None:
        """Generate images in batches using Flow HTTP Engine.

        Regular scenes are batched; scenes carrying a coverage_group_id
        are processed one-by-one (per the comment below, they need
        reference IDs).  Raises on failure — callers decide whether image
        generation is best-effort.  Assumes Flow credentials are present
        in the environment (GOOGLE_FLOW_COOKIES or the token pair).
        """
        output_dir = Path(self.config.output_dir) / "assets"
        output_dir.mkdir(parents=True, exist_ok=True)

        BATCH_SIZE = 10  # Flow supports up to 12, using 10 for safety

        console.print(
            f"\n[bold yellow]Generating images with Flow HTTP Engine (batches of {BATCH_SIZE})...[/bold yellow]"
        )

        try:
            import asyncio
            from providers.flow_http_engine import FlowHttpEngine, ASPECT_RATIOS
            import os

            # Get cookies from environment
            cookies_json = os.getenv("GOOGLE_FLOW_COOKIES", "")

            if not cookies_json:
                # Fallback: build from individual tokens
                session_token = os.getenv("GOOGLE_FLOW_SESSION_TOKEN", "")
                csrf_token = os.getenv("GOOGLE_FLOW_CSRF_TOKEN", "")

                if session_token and csrf_token:
                    import json
                    cookies_json = json.dumps([
                        {"name": "__Secure-next-auth.session-token", "value": session_token},
                        {"name": "__Host-next-auth.csrf-token", "value": csrf_token},
                        {"name": "__Secure-next-auth.callback-url", "value": "https%3A%2F%2Flabs.google%2Ffx%2Fpt%2Ftools%2Fflow"},
                    ])
                else:
                    raise RuntimeError("GOOGLE_FLOW_COOKIES or session tokens required")

            aspect = ASPECT_RATIOS.get(
                os.getenv("GOOGLE_FLOW_ASPECT_RATIO", "portrait").lower(),
                "IMAGE_ASPECT_RATIO_PORTRAIT",
            )

            # Dedicated event loop so this sync method can drive the
            # async engine; torn down in the finally block below.
            engine = FlowHttpEngine(cookies_json)
            loop = asyncio.new_event_loop()
            try:
                ok = loop.run_until_complete(engine.start())
                if not ok:
                    raise RuntimeError("FlowHttpEngine failed to start")
                console.print(
                    f"[dim]Using Flow HTTP Engine — model={engine.model}[/dim]"
                )

                # Separate scenes into coverage groups and regular scenes
                coverage_scenes = []
                regular_scenes = []

                for s in plan.scenes:
                    coverage_group = getattr(s, "coverage_group_id", None)
                    if coverage_group:
                        coverage_scenes.append(s)
                    else:
                        regular_scenes.append(s)

                # ── Process regular scenes in batches ─────────────────────────
                if regular_scenes:
                    # Ceiling division to count batches.
                    total_batches = (len(regular_scenes) + BATCH_SIZE - 1) // BATCH_SIZE
                    console.print(
                        f"[dim]Processing {len(regular_scenes)} regular scenes in {total_batches} batch(es)[/dim]"
                    )

                    for batch_num in range(total_batches):
                        start_idx = batch_num * BATCH_SIZE
                        end_idx = min(start_idx + BATCH_SIZE, len(regular_scenes))
                        batch_scenes = regular_scenes[start_idx:end_idx]

                        console.print(
                            f"\n[bold blue]Batch {batch_num + 1}/{total_batches} "
                            f"({len(batch_scenes)} images)...[/bold blue]"
                        )

                        # Prepare batch items
                        batch_items = []
                        for s in batch_scenes:
                            img_path = output_dir / f"frame_{s.scene_number:02d}.png"
                            batch_items.append((s.image_prompt, img_path))

                        # Generate batch.
                        # NOTE(review): assumes `current` is 1-based and
                        # in batch order — confirm against engine docs.
                        def on_progress(current, total):
                            scene = batch_scenes[current - 1]
                            console.print(
                                f"[green]✓ Frame {scene.scene_number}/{plan.total_scenes}[/green]"
                            )
                            self._emit(
                                event_type="image_generated",
                                image_scene_number=scene.scene_number,
                                image_base64=None,
                                image_path=str(output_dir / f"frame_{scene.scene_number:02d}.png"),
                                progress_current=scene.scene_number,
                                progress_total=plan.total_scenes,
                                message=f"Image {scene.scene_number}/{plan.total_scenes} generated",
                            )

                        results = loop.run_until_complete(
                            engine.generate_batch(
                                batch_items,
                                aspect_ratio=aspect,
                                on_progress=on_progress,
                            )
                        )

                        # None results count as failures for the summary line.
                        successful = sum(1 for r in results if r is not None)
                        console.print(
                            f"[dim]Batch {batch_num + 1} complete: {successful}/{len(batch_scenes)} successful[/dim]"
                        )

                # ── Process coverage scenes sequentially (need reference IDs) ─
                if coverage_scenes:
                    console.print(
                        f"\n[dim]Processing {len(coverage_scenes)} coverage scenes sequentially[/dim]"
                    )

                    for s in coverage_scenes:
                        img_path = output_dir / f"frame_{s.scene_number:02d}.png"

                        loop.run_until_complete(
                            engine.generate(
                                s.image_prompt, img_path, aspect_ratio=aspect
                            )
                        )

                        console.print(
                            f"[green]✓ Frame {s.scene_number}/{plan.total_scenes}[/green]"
                        )

                        self._emit(
                            event_type="image_generated",
                            image_scene_number=s.scene_number,
                            image_base64=None,
                            image_path=str(img_path),
                            progress_current=s.scene_number,
                            progress_total=plan.total_scenes,
                            message=f"Image {s.scene_number}/{plan.total_scenes} generated",
                        )

            finally:
                # Always stop the engine and close the loop we created.
                loop.run_until_complete(engine.stop())
                loop.close()
            console.print("\n[bold green]All images generated![/bold green]")
        except Exception as exc:
            console.print(f"[red]Image generation failed: {exc}[/red]")
            raise

    def _generate_all_assets(self, plan: ProductionPlan) -> None:
        """Generate all assets: images, TTS, video (for GENERATE mode).

        Images are required (failures propagate via _generate_images);
        TTS and video are best-effort — their exceptions are logged and
        swallowed so a partial asset set is still produced.
        """
        output_dir = Path(self.config.output_dir) / "assets"
        output_dir.mkdir(parents=True, exist_ok=True)

        console.print(
            "\n[bold yellow]Generating all assets (this may take a while)..."
            "[/bold yellow]"
        )

        # Images first
        self._generate_images(plan)

        # TTS (using factory to support ElevenLabs or Pollinations)
        try:
            from providers import get_tts_provider

            provider_name = getattr(
                self.config.voice_settings, "tts_provider", "elevenlabs"
            )
            console.print(
                f"\n[bold yellow]Generating TTS audio via {provider_name.upper()}...[/bold yellow]"
            )

            tts_prov = get_tts_provider(self.config.voice_settings)

            for s in plan.scenes:
                # Use .mp3 for all providers
                ext = ".mp3"
                audio_path = output_dir / f"frame_{s.scene_number:02d}{ext}"
                tts_prov.generate(s.tts_text, audio_path)
                console.print(
                    f"[green]✓ Audio {s.scene_number}/{plan.total_scenes} ({provider_name})[/green]"
                )

            console.print(
                f"[bold green]All TTS audio generated via {provider_name.upper()}![/bold green]"
            )
        except Exception as exc:
            # Best-effort: log with traceback and continue to video stage.
            console.print(f"[red]TTS generation failed: {exc}[/red]")
            logger.exception("TTS generation error")

        # Video clips via Grok I2V
        try:
            from providers.grok_video_provider import GrokVideoProvider

            console.print(
                "\n[bold yellow]Generating video clips via Grok I2V...[/bold yellow]"
            )
            self._emit(
                event_type="agent_start",
                agent_name="Video Generator",
                message="Generating video clips via Grok I2V...",
                progress_current=0,
                progress_total=plan.total_scenes,
            )

            video_prov = GrokVideoProvider()

            for s in plan.scenes:
                img_path = output_dir / f"frame_{s.scene_number:02d}.png"
                vid_path = output_dir / f"frame_{s.scene_number:02d}.mp4"

                # Video needs its source image; skip scenes whose image
                # generation failed earlier.
                if not img_path.exists():
                    logger.warning(
                        "Image not found for scene %d, skipping video", s.scene_number
                    )
                    continue

                # Default arg `_sn` binds the scene number now, avoiding
                # the late-binding-closure pitfall in this loop.
                def _on_video_progress(progress: int, _sn=s.scene_number) -> None:
                    self._emit(
                        event_type="agent_progress",
                        agent_name="Video Generator",
                        message=f"Scene {_sn}: {progress}%",
                        progress_current=_sn,
                        progress_total=plan.total_scenes,
                    )

                try:
                    video_prov.generate(
                        image_path=img_path,
                        video_prompt=s.video_prompt,
                        output_path=vid_path,
                        on_progress=_on_video_progress,
                    )
                    console.print(
                        f"[green]✓ Video {s.scene_number}/{plan.total_scenes}[/green]"
                    )
                except Exception as vid_exc:
                    # Per-scene failures don't abort the remaining clips.
                    logger.error(
                        "Video failed for scene %d: %s", s.scene_number, vid_exc
                    )
                    console.print(
                        f"[red]✗ Video {s.scene_number} failed: {vid_exc}[/red]"
                    )

            video_prov.close()
            console.print("[bold green]Video generation complete![/bold green]")
            self._emit(
                event_type="agent_complete",
                agent_name="Video Generator",
                message=f"Video generation complete ({plan.total_scenes} clips)",
            )

        except EnvironmentError as env_exc:
            # Missing credentials/config: skip quietly rather than fail.
            console.print(f"\n[dim]Video generation skipped: {env_exc}[/dim]")
        except Exception as exc:
            console.print(f"[red]Video generation failed: {exc}[/red]")
            logger.exception("Video generation failed")
