"""
ASS Subtitle Engine — Production-quality subtitle generation.

Generates Advanced SubStation Alpha (.ass) subtitle files with:
- Word-by-word karaoke highlighting (\\kf fill-sweep)
- Smart word grouping (pause-based, time-based, fixed-count, single-word)
- 9-position alignment (\\an1–\\an9)
- Fade/pop/slam animations
- 7 built-in style presets (tiktok_karaoke, minimal, bold_center, neon, classic_bottom,
  single_word_pop, time_marker)
- Duration clipping to prevent subtitle overflow
- Per-scene positioning support

Replaces the old drawtext filter chain with a single FFmpeg `ass=` filter.
"""

from __future__ import annotations

import logging
import re
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Font paths — bundled Montserrat + system fallback
# ---------------------------------------------------------------------------
# Directory named "fonts" located next to this module file.
_FONTS_DIR = Path(__file__).resolve().parent / "fonts"
# Path of the bundled Montserrat font file inside that directory.
MONTSERRAT_FONT = _FONTS_DIR / "Montserrat.ttf"
# Absolute path of the system fallback font (DejaVu Sans Bold).
# NOTE(review): neither font path is referenced by the code visible in this
# module (styles use family names only) — confirm external users before removing.
FALLBACK_FONT = Path("/usr/share/fonts/TTF/DejaVuSans-Bold.ttf")


def _resolve_font(name: str) -> str:
    """Return a font family name for ASS (libass resolves via fontconfig)."""
    mapping = {
        "montserrat": "Montserrat",
        "montserrat black": "Montserrat",
        "dejavu": "DejaVu Sans",
        "dejavu sans": "DejaVu Sans",
        "arial": "Arial",
        "inter": "Inter",
        "sans": "Sans",
    }
    return mapping.get(name.lower(), name)


# ---------------------------------------------------------------------------
# Color helpers  (ASS uses &HAABBGGRR format)
# ---------------------------------------------------------------------------


def _hex_to_ass(hex_color: str, alpha: int = 0) -> str:
    """Convert #RRGGBB or #AARRGGBB to ASS &HAABBGGRR format."""
    h = hex_color.lstrip("#")
    if len(h) == 8:
        a, r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16), int(h[6:8], 16)
    elif len(h) == 6:
        a = alpha
        r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
    else:
        return "&H00FFFFFF"
    return f"&H{a:02X}{b:02X}{g:02X}{r:02X}"


def _rgba_to_ass(rgba: str) -> str:
    """Convert rgba(r,g,b,a) to ASS &HAABBGGRR."""
    m = re.match(
        r"rgba?\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*(?:,\s*([\d.]+))?\s*\)", rgba
    )
    if not m:
        return "&H00000000"
    r, g, b = int(m.group(1)), int(m.group(2)), int(m.group(3))
    a = float(m.group(4)) if m.group(4) else 1.0
    alpha = int((1.0 - a) * 255)  # ASS: 0=opaque, 255=transparent
    return f"&H{alpha:02X}{b:02X}{g:02X}{r:02X}"


def _parse_color(val: str | None, default: str = "&H00FFFFFF") -> str:
    """Accept ASS, hex, or rgba color and return ASS format."""
    if not val:
        return default
    if val.startswith("&H"):
        return val
    if val.startswith("#"):
        return _hex_to_ass(val)
    if val.startswith("rgb"):
        return _rgba_to_ass(val)
    return default


# ---------------------------------------------------------------------------
# Style presets
# ---------------------------------------------------------------------------

# Built-in style presets, keyed by preset name. generate_ass() copies the
# selected preset and then applies any non-None caller overrides on top.
#
# Key reference (as consumed by generate_ass / group_words in this module):
#   fontFamily                 font family name, passed through _resolve_font
#   fontSize                   written to the style's Fontsize field
#   primaryColor / secondaryColor / outlineColor / backColor
#                              colours in ASS &HAABBGGRR notation (hex and
#                              rgb()/rgba() are also accepted via _parse_color)
#   bold                       truthy -> Bold field -1, falsy -> 0
#   borderStyle                3 = box background, 1 = outline only
#   outline / shadow           written to the style's Outline / Shadow fields
#   alignment                  numpad-style position (\an1–\an9)
#   marginV / marginL / marginR  written to the style's margin fields
#   animation                  "fade", "pop" or "slam"; other values add no tag
#   fadeDuration               fade in/out time in ms (only used by "fade")
#   karaoke                    emit per-word karaoke tags when truthy
#   karaokeMode                "fill" -> \kf, anything else -> \k
#   grouping                   "single_word", "fixed", "time"; any other value
#                              is treated as "pause"
#   groupSize                  chunk size for "fixed"; max words per group for
#                              "time" and "pause"
#   maxGroupDuration           seconds; group length cap for "time" grouping
#   pauseThreshold             seconds; silence that splits groups for "pause"
SUBTITLE_PRESETS: dict[str, dict[str, Any]] = {
    "tiktok_karaoke": {
        "fontFamily": "Montserrat",
        "fontSize": 90,
        "primaryColor": "&H0000FFFF",  # Yellow (highlighted word)
        "secondaryColor": "&H00FFFFFF",  # White (unhighlighted)
        "outlineColor": "&H00000000",  # Black outline
        "backColor": "&H99000000",  # Semi-transparent black box
        "bold": True,
        "borderStyle": 3,  # Box background
        "outline": 2,
        "shadow": 0,
        "alignment": 2,  # \an2 bottom-center
        "marginV": 140,
        "marginL": 60,
        "marginR": 60,
        "animation": "fade",
        "fadeDuration": 80,  # ms
        "karaoke": True,
        "karaokeMode": "fill",  # \kf (fill sweep)
        "grouping": "pause",
        "groupSize": 4,
        "maxGroupDuration": 1.5,
        "pauseThreshold": 0.3,
    },
    "minimal": {
        "fontFamily": "DejaVu Sans",
        "fontSize": 72,
        "primaryColor": "&H00FFFFFF",
        "secondaryColor": "&H00AAAAAA",
        "outlineColor": "&H00000000",
        "backColor": "&H00000000",  # No background
        "bold": False,
        "borderStyle": 1,  # Outline only
        "outline": 4,
        "shadow": 1,
        "alignment": 2,
        "marginV": 120,
        "marginL": 60,
        "marginR": 60,
        "animation": "fade",
        "fadeDuration": 150,
        "karaoke": False,
        "grouping": "time",
        "groupSize": 5,
        "maxGroupDuration": 2.0,
        "pauseThreshold": 0.3,
    },
    "bold_center": {
        "fontFamily": "Montserrat",
        "fontSize": 110,
        "primaryColor": "&H00FFFFFF",
        "secondaryColor": "&H0080FFFF",  # Lighter yellow
        "outlineColor": "&H00000000",
        "backColor": "&H00000000",
        "bold": True,
        "borderStyle": 1,
        "outline": 5,
        "shadow": 2,
        "alignment": 5,  # \an5 center
        "marginV": 0,
        "marginL": 80,
        "marginR": 80,
        "animation": "pop",
        "fadeDuration": 0,
        "karaoke": True,
        "karaokeMode": "fill",
        "grouping": "fixed",
        "groupSize": 3,
        "maxGroupDuration": 1.5,
        "pauseThreshold": 0.3,
    },
    "neon": {
        "fontFamily": "Montserrat",
        "fontSize": 85,
        "primaryColor": "&H00FF00FF",  # Magenta
        "secondaryColor": "&H00FFFF00",  # Cyan
        "outlineColor": "&H00FF00FF",
        "backColor": "&H00000000",
        "bold": True,
        "borderStyle": 1,
        "outline": 4,
        "shadow": 0,
        "alignment": 2,
        "marginV": 140,
        "marginL": 60,
        "marginR": 60,
        "animation": "fade",
        "fadeDuration": 100,
        "karaoke": True,
        "karaokeMode": "fill",
        "grouping": "pause",
        "groupSize": 4,
        "maxGroupDuration": 1.5,
        "pauseThreshold": 0.3,
    },
    "classic_bottom": {
        "fontFamily": "DejaVu Sans",
        "fontSize": 70,
        "primaryColor": "&H00FFFFFF",
        "secondaryColor": "&H00FFFFFF",
        "outlineColor": "&H00000000",
        "backColor": "&H80000000",
        "bold": False,
        "borderStyle": 3,
        "outline": 1,
        "shadow": 0,
        "alignment": 2,
        "marginV": 100,
        "marginL": 60,
        "marginR": 60,
        "animation": "fade",
        "fadeDuration": 200,
        "karaoke": False,
        "grouping": "time",
        "groupSize": 5,
        "maxGroupDuration": 2.5,
        "pauseThreshold": 0.3,
    },
    "single_word_pop": {
        "fontFamily": "Montserrat",
        "fontSize": 120,
        "primaryColor": "&H00FFFFFF",  # White
        "secondaryColor": "&H00FFFFFF",
        "outlineColor": "&H00000000",  # Black outline
        "backColor": "&H00000000",     # No background
        "bold": True,
        "borderStyle": 1,              # Outline only
        "outline": 8,                  # Thick outline for readability
        "shadow": 3,
        "alignment": 2,                # \an2 bottom-center (below time markers)
        "marginV": 200,               # Well above TikTok UI chrome
        "marginL": 60,
        "marginR": 60,
        "animation": "slam",
        "fadeDuration": 0,
        "karaoke": False,
        "grouping": "single_word",
        "groupSize": 1,
        "maxGroupDuration": 99,
        "pauseThreshold": 99,
    },
    "time_marker": {
        "fontFamily": "Montserrat",
        "fontSize": 140,
        "primaryColor": "&H00FFFFFF",  # White
        "secondaryColor": "&H00FFFFFF",
        "outlineColor": "&H00000000",  # Black outline
        "backColor": "&H00000000",
        "bold": True,
        "borderStyle": 1,
        "outline": 6,
        "shadow": 3,
        "alignment": 5,                # \an5 center
        "marginV": 0,
        "marginL": 0,
        "marginR": 0,
        "animation": "slam",
        "fadeDuration": 0,
        "karaoke": False,
        "grouping": "fixed",
        "groupSize": 99,               # All words in one group for markers
        "maxGroupDuration": 99,
        "pauseThreshold": 99,
    },
}


def list_presets() -> list[dict[str, Any]]:
    """Return preset names with descriptions for the API.

    Returns
    -------
    One ``{"name", "description", "config"}`` dict per entry in
    ``SUBTITLE_PRESETS``, in definition order. ``config`` is a shallow copy of
    the preset, so callers that edit the returned data cannot corrupt the
    module-level presets used by later subtitle generation.
    """
    descriptions = {
        "tiktok_karaoke": "TikTok-style karaoke with yellow word highlighting on black box",
        "minimal": "Clean white text with subtle outline, no background",
        "bold_center": "Large bold centered text with karaoke highlighting",
        "neon": "Vibrant magenta/cyan neon-style karaoke subtitles",
        "classic_bottom": "Traditional bottom subtitles with semi-transparent box",
        "single_word_pop": "One word at a time, centered, with slam-in animation",
        "time_marker": "Large centered text for time markers (Day 1, Week 2) with slam effect",
    }
    return [
        {
            "name": name,
            # Presets without a description entry get an empty string.
            "description": descriptions.get(name, ""),
            "config": dict(config),
        }
        for name, config in SUBTITLE_PRESETS.items()
    ]


# ---------------------------------------------------------------------------
# Word grouping strategies
# ---------------------------------------------------------------------------


def _group_fixed(words: list[dict], size: int = 4) -> list[list[dict]]:
    """Group words into fixed-count chunks."""
    groups: list[list[dict]] = []
    for i in range(0, len(words), size):
        groups.append(words[i : i + size])
    return groups


def _group_by_time(
    words: list[dict],
    max_duration: float = 1.5,
    max_words: int = 5,
) -> list[list[dict]]:
    """Group words until accumulated duration hits max or word cap."""
    groups: list[list[dict]] = []
    current: list[dict] = []
    start_t = 0.0
    for w in words:
        if not current:
            start_t = w["start"]
        current.append(w)
        elapsed = w["end"] - start_t
        if elapsed >= max_duration or len(current) >= max_words:
            groups.append(current)
            current = []
    if current:
        groups.append(current)
    return groups


def _group_by_pause(
    words: list[dict],
    threshold: float = 0.3,
    max_words: int = 5,
) -> list[list[dict]]:
    """Split at natural speech pauses (gap > threshold between words)."""
    if not words:
        return []
    groups: list[list[dict]] = []
    current: list[dict] = [words[0]]
    for i in range(1, len(words)):
        gap = words[i]["start"] - words[i - 1]["end"]
        if gap > threshold or len(current) >= max_words:
            groups.append(current)
            current = []
        current.append(words[i])
    if current:
        groups.append(current)
    return groups




def _group_single_word(words: list[dict]) -> list[list[dict]]:
    """Each word is its own group — one word at a time display."""
    return [[w] for w in words]

def group_words(
    words: list[dict],
    strategy: str = "pause",
    group_size: int = 4,
    max_duration: float = 1.5,
    pause_threshold: float = 0.3,
) -> list[list[dict]]:
    """Split word-level timing data into phrase blocks using the named strategy.

    Recognised strategies are "single_word", "fixed" and "time". Every other
    value, including the default "pause", selects pause-based grouping.
    *group_size* is the chunk size for "fixed" and the per-group word cap for
    "time" and "pause".
    """
    if strategy == "single_word":
        return _group_single_word(words)
    if strategy == "fixed":
        return _group_fixed(words, group_size)
    if strategy == "time":
        return _group_by_time(words, max_duration, group_size)
    # "pause" and any unrecognised strategy name
    return _group_by_pause(words, pause_threshold, group_size)


# ---------------------------------------------------------------------------
# ASS time formatting
# ---------------------------------------------------------------------------


def _fmt_ass_time(seconds: float) -> str:
    """Format seconds as H:MM:SS.cc (ASS timestamp format — centiseconds)."""
    if seconds < 0:
        seconds = 0.0
    h = int(seconds // 3600)
    m = int((seconds % 3600) // 60)
    s = seconds % 60
    return f"{h}:{m:02d}:{s:05.2f}"


# ---------------------------------------------------------------------------
# ASS escape
# ---------------------------------------------------------------------------


def _escape_ass(text: str) -> str:
    """Escape special characters for ASS dialogue text."""
    # ASS uses \N for newline, \n for soft wrap, \h for hard space
    # Backslashes that aren't ASS tags need escaping
    text = text.replace("\n", "\\N")
    return text


# ---------------------------------------------------------------------------
# Core ASS generator
# ---------------------------------------------------------------------------


def _normalize_word_timings(
    word_timings: list[dict], video_duration: float
) -> list[dict]:
    """Return cleaned ``{"text", "start", "end"}`` dicts clipped to *video_duration*."""
    normalized: list[dict] = []
    for w in word_timings:
        # Accept both "text" and "word" keys; tolerate None / non-string values.
        text = str(w.get("text") or w.get("word") or "").strip()
        if not text:
            continue
        raw_start = w.get("start")
        start = float(raw_start) if raw_start is not None else 0.0
        raw_end = w.get("end")
        end = float(raw_end) if raw_end is not None else start + 0.1
        # Clip to video duration
        if start >= video_duration:
            continue
        if end > video_duration:
            end = video_duration
        # Minimum word duration: 50ms
        if end - start < 0.05:
            end = min(start + 0.05, video_duration)
        normalized.append({"text": text, "start": start, "end": end})
    return normalized


def _animation_tag(animation: str, fade_dur: int) -> str:
    """Return the override tags for the entrance animation ("" when none applies)."""
    if animation == "fade" and fade_dur > 0:
        return f"\\fad({fade_dur},{fade_dur})"
    if animation == "pop":
        # Scale from 80% → 100% over 200ms
        return "\\fscx80\\fscy80\\t(0,200,\\fscx100\\fscy100)"
    if animation == "slam":
        # Scale from 150% → 100% over 80ms (fast slam/impact)
        return "\\fscx150\\fscy150\\t(0,80,\\fscx100\\fscy100)"
    return ""


def _next_group_start(groups: list[list[dict]], index: int) -> float | None:
    """Return the start time of the first non-empty group after *index*, if any."""
    for later in range(index + 1, len(groups)):
        if groups[later]:
            return groups[later][0]["start"]
    return None


def _karaoke_text(phrase_words: list[dict], ktag: str, anim_tag: str) -> str:
    """Build the dialogue text for one phrase with per-word karaoke tags."""
    # Karaoke durations are consumed back-to-back from the start of the line,
    # so a silence between two words must be spent explicitly. Otherwise every
    # later word lights up early by the sum of the preceding gaps.
    line_start = phrase_words[0]["start"]
    elapsed_cs = 0  # centiseconds of karaoke timeline emitted so far
    parts: list[str] = []
    for i, w in enumerate(phrase_words):
        word_dur_cs = max(5, round((w["end"] - w["start"]) * 100))
        escaped = _escape_ass(w["text"])
        if i == 0:
            # First word carries the animation tag
            parts.append(f"{{{anim_tag}{ktag}{word_dur_cs}}}{escaped}")
        else:
            gap_cs = round((w["start"] - line_start) * 100) - elapsed_cs
            if gap_cs > 0:
                # The separating space becomes its own syllable that absorbs
                # the silence.
                parts.append(f"{{\\k{gap_cs}}}")
                elapsed_cs += gap_cs
            parts.append(f" {{{ktag}{word_dur_cs}}}{escaped}")
        elapsed_cs += word_dur_cs
    return "".join(parts)


def generate_ass(
    word_timings: list[dict],
    output_path: str | Path,
    video_duration: float,
    style: dict[str, Any] | None = None,
    video_width: int = 1080,
    video_height: int = 1920,
) -> Path:
    """
    Generate an ASS subtitle file from word-level timing data.

    Parameters
    ----------
    word_timings : list[dict]
        Each dict has keys: text/word (str), start (float seconds), end (float seconds).
        Entries with empty text, or starting at/after ``video_duration``, are dropped.
        A missing start defaults to 0 and a missing end to start + 0.1.
    output_path : str | Path
        Where to write the .ass file (UTF-8 with BOM).
    video_duration : float
        Maximum allowed end time — subtitles are clipped to this.
    style : dict, optional
        Subtitle style configuration. Keys from SUBTITLE_PRESETS or custom overrides.
        Use "preset" key to start from a preset, then override individual keys.
        Overrides whose value is None are ignored. A missing or unknown preset
        falls back to "tiktok_karaoke".
    video_width, video_height : int
        Video resolution for PlayResX/PlayResY.

    Returns
    -------
    Path to the generated .ass file. When no usable word timings remain, a
    minimal file with no dialogue lines is written instead.
    """
    output_path = Path(output_path)

    # Resolve style: start from preset, apply overrides
    style = style or {}
    preset_name = style.get("preset") or "tiktok_karaoke"
    if preset_name not in SUBTITLE_PRESETS:
        logger.warning(
            "Unknown subtitle preset %r — falling back to tiktok_karaoke",
            preset_name,
        )
        preset_name = "tiktok_karaoke"
    cfg = dict(SUBTITLE_PRESETS[preset_name])
    # Apply user overrides (excluding "preset" key)
    for k, v in style.items():
        if k != "preset" and v is not None:
            cfg[k] = v

    # Parse colors (support hex/rgba input)
    primary = _parse_color(cfg.get("primaryColor"), "&H0000FFFF")
    secondary = _parse_color(cfg.get("secondaryColor"), "&H00FFFFFF")
    outline_c = _parse_color(cfg.get("outlineColor"), "&H00000000")
    back_c = _parse_color(cfg.get("backColor"), "&H99000000")

    font_name = _resolve_font(cfg.get("fontFamily", "Montserrat"))
    font_size = cfg.get("fontSize", 48)
    bold = -1 if cfg.get("bold", True) else 0
    border_style = cfg.get("borderStyle", 3)
    outline = cfg.get("outline", 0)
    shadow = cfg.get("shadow", 0)
    alignment = cfg.get("alignment", 2)
    margin_l = cfg.get("marginL", 30)
    margin_r = cfg.get("marginR", 30)
    margin_v = cfg.get("marginV", 60)
    karaoke = cfg.get("karaoke", True)
    karaoke_mode = cfg.get("karaokeMode", "fill")  # "fill" (\kf) or "instant" (\k)
    animation = cfg.get("animation", "fade")  # "fade", "pop", "slam", "none"
    fade_dur = cfg.get("fadeDuration", 80)  # ms
    grouping = cfg.get("grouping", "pause")
    group_size = cfg.get("groupSize", 4)
    max_group_dur = cfg.get("maxGroupDuration", 1.5)
    pause_thresh = cfg.get("pauseThreshold", 0.3)

    normalized = _normalize_word_timings(word_timings, video_duration)

    if not normalized:
        logger.warning("No valid word timings — writing empty ASS file")
        output_path.write_text(
            _empty_ass(video_width, video_height), encoding="utf-8-sig"
        )
        return output_path

    # Group words into phrases
    groups = group_words(
        normalized,
        strategy=grouping,
        group_size=group_size,
        max_duration=max_group_dur,
        pause_threshold=pause_thresh,
    )

    # Build ASS file content
    lines: list[str] = []

    # [Script Info]
    lines.append("[Script Info]")
    lines.append("ScriptType: v4.00+")
    lines.append("ScaledBorderAndShadow: yes")
    lines.append(f"PlayResX: {video_width}")
    lines.append(f"PlayResY: {video_height}")
    lines.append(f"LayoutResX: {video_width}")
    lines.append(f"LayoutResY: {video_height}")
    lines.append("WrapStyle: 2")
    lines.append("YCbCr Matrix: TV.709")
    lines.append("")

    # [V4+ Styles]
    lines.append("[V4+ Styles]")
    lines.append(
        "Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, "
        "OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, "
        "ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, "
        "Alignment, MarginL, MarginR, MarginV, Encoding"
    )
    lines.append(
        f"Style: Default,{font_name},{font_size},{primary},{secondary},"
        f"{outline_c},{back_c},{bold},0,0,0,"
        f"100,100,0,0,{border_style},{outline},{shadow},"
        f"{alignment},{margin_l},{margin_r},{margin_v},1"
    )
    lines.append("")

    # [Events]
    lines.append("[Events]")
    lines.append(
        "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
    )

    # These depend only on the style, not on the phrase — compute them once.
    anim_tag = _animation_tag(animation, fade_dur)
    ktag = "\\kf" if karaoke_mode == "fill" else "\\k"
    min_display = 0.3  # seconds; minimum on-screen time for grouped modes

    # Generate dialogue lines for each phrase group
    for gi, phrase_words in enumerate(groups):
        if not phrase_words:
            continue
        p_start = phrase_words[0]["start"]
        p_end = phrase_words[-1]["end"]
        next_start = _next_group_start(groups, gi)

        if grouping != "single_word":
            # For grouped modes, enforce minimum display time without running
            # into the next phrase or past the end of the video.
            if p_end - p_start < min_display:
                limit = video_duration if next_start is None else next_start
                p_end = min(p_start + min_display, limit, video_duration)
        elif next_start is not None and p_end > next_start:
            # HARD CLIP: never overlap with next group (belt-and-suspenders)
            p_end = next_start

        # Clip to video duration
        if p_start >= video_duration:
            continue
        p_end = min(p_end, video_duration)
        # Clipping can leave nothing to show; an event that ends at or before
        # its own start would never be visible, so do not emit it.
        if p_end <= p_start:
            continue

        start_str = _fmt_ass_time(p_start)
        end_str = _fmt_ass_time(p_end)

        if karaoke:
            dialogue_text = _karaoke_text(phrase_words, ktag, anim_tag)
        else:
            # No karaoke — just show the full phrase
            phrase_text = " ".join(_escape_ass(w["text"]) for w in phrase_words)
            if anim_tag:
                dialogue_text = f"{{{anim_tag}}}{phrase_text}"
            else:
                dialogue_text = phrase_text

        lines.append(
            f"Dialogue: 0,{start_str},{end_str},Default,,0,0,0,,{dialogue_text}"
        )

    content = "\n".join(lines) + "\n"
    output_path.write_text(content, encoding="utf-8-sig")
    logger.info(
        "Generated ASS subtitles: %d phrases, %d words, style=%s, karaoke=%s",
        len(groups),
        len(normalized),
        preset_name,
        karaoke,
    )
    return output_path


def _empty_ass(w: int = 1080, h: int = 1920) -> str:
    """Return a minimal valid ASS file with no dialogue lines."""
    return (
        "[Script Info]\n"
        "ScriptType: v4.00+\n"
        f"PlayResX: {w}\n"
        f"PlayResY: {h}\n\n"
        "[V4+ Styles]\n"
        "Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, "
        "OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, "
        "ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, "
        "Alignment, MarginL, MarginR, MarginV, Encoding\n"
        "Style: Default,Sans,40,&H00FFFFFF,&H00FFFFFF,"
        "&H00000000,&H00000000,0,0,0,0,"
        "100,100,0,0,1,2,0,2,10,10,10,1\n\n"
        "[Events]\n"
        "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n"
    )
