"""
Audio/video synchronization with production-grade atempo chaining.

Speed changes use FFmpeg's pitch-preserving, WSOLA-based atempo filter, chained
whenever a factor falls outside the [0.5, 2.0] range allowed per filter instance.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


def build_atempo_filter(speed_factor: float, max_chain: int = 4) -> str:
    """
    Build FFmpeg atempo filter chain for any speed factor.

    Each atempo instance is kept within [0.5, 2.0]. Factors outside that range
    are reached by chaining full 0.5 (slow-down) or 2.0 (speed-up) steps and
    finishing with one partial filter for the remainder.

    Args:
        speed_factor: Playback speed factor (>1 = faster, <1 = slower). Must be
            a positive, finite number.
        max_chain: Maximum number of full 0.5/2.0 steps. One partial filter may
            follow, so the chain has at most ``max_chain + 1`` entries. If the
            factor needs more steps than allowed, a warning is logged and the
            chain ends with one more full step. That is the closest reachable
            speed, instead of silently dropping to a much smaller factor.

    Returns: FFmpeg filter string, or "" when no adjustment is needed.

    Raises:
        ValueError: If ``speed_factor`` is zero, negative, NaN or infinite.

    Examples:
        0.25x speed → "atempo=0.5,atempo=0.500000"
        1.0x speed  → "" (empty, no adjustment needed)
        1.25x speed → "atempo=1.250000"
        3.0x speed  → "atempo=2.0,atempo=1.500000"
        4.0x speed  → "atempo=2.0,atempo=2.000000"
    """
    # NaN fails both comparisons, so this also rejects NaN as well as ±inf.
    if not (0.0 < speed_factor < float("inf")):
        raise ValueError(
            f"speed_factor must be a positive finite number, got {speed_factor!r}"
        )

    # No adjustment needed
    if abs(speed_factor - 1.0) < 1e-6:
        return ""

    # Already in safe range, single filter
    if 0.5 <= speed_factor <= 2.0:
        return f"atempo={speed_factor:.6f}"

    # Out of range: chain full steps toward [0.5, 2.0], then add the remainder.
    step = 0.5 if speed_factor < 0.5 else 2.0
    filters: list[str] = []
    remaining = speed_factor
    while not (0.5 <= remaining <= 2.0):
        if len(filters) >= max_chain:
            logger.warning(
                "Speed factor %s required >%s chain depth, clamping",
                speed_factor,
                max_chain,
            )
            # Depth exhausted: the closest reachable speed is one more full step,
            # emitted below as the final partial filter.
            remaining = step
            break
        filters.append(f"atempo={step}")
        remaining /= step

    # Add final partial filter if remainder is not 1.0
    if abs(remaining - 1.0) > 1e-6:
        filters.append(f"atempo={remaining:.6f}")

    return ",".join(filters)


@dataclass
class DurationMismatch:
    """Describes an audio/video duration mismatch and its correction strategy.

    Only ``video_duration`` and ``audio_duration`` are real inputs. Every other
    field is recomputed in ``__post_init__``, and any value passed for it is
    ignored, so those fields default to neutral values and may be omitted.

    ``adjustment_factor`` is the playback-speed factor for the stream named by
    ``strategy`` (>1 = faster/shorter, <1 = slower/longer). In both adjusting
    strategies the shorter stream is slowed down to match the longer one:

    * ``"audio_adjust"`` (video longer): atempo value ``audio / video`` (< 1),
      so the audio lasts ``audio / factor == video`` seconds.
    * ``"video_adjust"`` (audio longer): speed ``video / audio`` (< 1). A PTS
      multiplier of ``1 / factor`` then stretches the video to the audio length.
    * ``"none"``: 1.0.

    ``percentage_diff`` is the absolute difference relative to the video
    duration, in percent.
    """

    video_duration: float
    audio_duration: float
    difference_seconds: float = 0.0
    percentage_diff: float = 0.0
    requires_adjustment: bool = False
    strategy: str = "none"  # "audio_adjust", "video_adjust", "none"
    adjustment_factor: float = 1.0

    def __post_init__(self):
        video = self.video_duration
        audio = self.audio_duration

        # Non-positive durations cannot be reconciled. Neither can NaN or inf,
        # which would otherwise leak into the factor. NaN fails every comparison,
        # so the chained check rejects it too.
        if not (0.0 < video < float("inf") and 0.0 < audio < float("inf")):
            self.requires_adjustment = False
            self.strategy = "none"
            self.adjustment_factor = 1.0
            self.difference_seconds = 0.0
            self.percentage_diff = 0.0
            return

        self.difference_seconds = abs(audio - video)
        self.percentage_diff = self.difference_seconds / video * 100

        # Decision logic:
        # Within 5% tolerance = no adjustment
        if self.percentage_diff < 5.0:
            self.requires_adjustment = False
            self.strategy = "none"
            self.adjustment_factor = 1.0
            return

        self.requires_adjustment = True
        if video < audio:
            # Video shorter than audio: slow the video down so it lasts as
            # long as the audio (speed < 1).
            self.strategy = "video_adjust"
            self.adjustment_factor = video / audio
        else:
            # Video longer than audio: slow the audio down (atempo < 1) so it
            # fills the video. Using video/audio here would speed the audio
            # up and widen the gap.
            self.strategy = "audio_adjust"
            self.adjustment_factor = audio / video  # atempo value


class AudioSyncCalculator:
    """Production-grade audio/video synchronization calculations.

    Every method is a static function of a DurationMismatch. The filter
    builders read ``mismatch.adjustment_factor`` as a playback-speed factor for
    the stream being adjusted (>1 = faster, <1 = slower).
    """

    @staticmethod
    def calculate_audio_filters(mismatch: DurationMismatch) -> list[str]:
        """
        Build FFmpeg audio filters to handle duration mismatch.

        Only the "audio_adjust" strategy produces filters: an ``aresample``
        entry followed by the atempo chain for ``mismatch.adjustment_factor``.

        Returns: List of filter strings in application order, e.g.
            ["aresample=async=0", "atempo=0.833333"]. Returns [] when no audio
            adjustment applies or the factor needs no atempo.
        """
        if not mismatch.requires_adjustment or mismatch.strategy != "audio_adjust":
            return []

        # Use atempo for pitch-preserving speed adjustment
        atempo_chain = build_atempo_filter(mismatch.adjustment_factor)
        if not atempo_chain:
            return []

        # NOTE(review): async=0 appears to be aresample's default (no timestamp
        # compensation), which would make this entry a pass-through. It is kept
        # for output compatibility; confirm the intended async setting.
        return ["aresample=async=0", atempo_chain]

    @staticmethod
    def calculate_video_filters(
        mismatch: DurationMismatch,
        fps: int = 30
    ) -> list[str]:
        """
        Build FFmpeg video filters to handle duration mismatch.

        For the "video_adjust" strategy, returns a single ``setpts`` filter that
        scales presentation timestamps by ``1 / adjustment_factor``. A factor
        below 1 (slower playback) stretches the video; above 1 compresses it.

        Args:
            mismatch: Duration analysis to act on.
            fps: Not used by the current implementation. Kept for API
                compatibility.

        Returns: e.g. ["setpts=1.200000*PTS"], or [] when no video adjustment
            applies.
        """
        if not mismatch.requires_adjustment or mismatch.strategy != "video_adjust":
            return []

        # setpts evaluates to a timestamp in the same time base as PTS, so scaling
        # PTS directly is enough. Appending "/TB" would divide by the time base
        # a second time and inflate every timestamp by 1/TB.
        pts_mult = 1.0 / mismatch.adjustment_factor
        return [f"setpts={pts_mult:.6f}*PTS"]

    @staticmethod
    def calculate_padding_seconds(mismatch: DurationMismatch) -> float:
        """
        Calculate silence padding if audio is shorter than video.

        Note: With atempo adjustment, we usually don't need padding.
        This is for fallback scenarios.

        Returns: ``video_duration - audio_duration`` when the audio is shorter
            and the strategy is not "audio_adjust"; otherwise 0.0.
        """
        if mismatch.audio_duration >= mismatch.video_duration:
            return 0.0

        # If using audio_adjust (atempo), no padding needed
        if mismatch.strategy == "audio_adjust":
            return 0.0

        return mismatch.video_duration - mismatch.audio_duration


class DurationReconciler:
    """Facade that reconciles one audio duration with one video duration.

    The mismatch analysis is computed once at construction. Every getter then
    delegates to AudioSyncCalculator using that stored analysis.
    """

    def __init__(self, audio_duration: float, video_duration: float):
        self.audio_duration = audio_duration
        self.video_duration = video_duration
        # Only the two durations are meaningful inputs. DurationMismatch's
        # __post_init__ recomputes every other field, so neutral placeholders
        # are supplied for them.
        placeholders = {
            "difference_seconds": 0.0,
            "percentage_diff": 0.0,
            "requires_adjustment": False,
            "strategy": "none",
            "adjustment_factor": 1.0,
        }
        self.mismatch = DurationMismatch(
            video_duration=video_duration,
            audio_duration=audio_duration,
            **placeholders,
        )

    def get_mismatch(self) -> DurationMismatch:
        """Return the stored mismatch analysis."""
        return self.mismatch

    def get_audio_filters(self) -> list[str]:
        """Return the FFmpeg filters for the audio stream (possibly empty)."""
        analysis = self.mismatch
        return AudioSyncCalculator.calculate_audio_filters(analysis)

    def get_video_filters(self, fps: int = 30) -> list[str]:
        """Return the FFmpeg filters for the video stream (possibly empty)."""
        return AudioSyncCalculator.calculate_video_filters(self.mismatch, fps=fps)

    def get_padding_seconds(self) -> float:
        """Return the seconds of silence padding required, or 0.0 if none."""
        analysis = self.mismatch
        return AudioSyncCalculator.calculate_padding_seconds(analysis)

    def summary(self) -> dict:
        """Return a plain-dict snapshot of the analysis, rounded for display."""
        m = self.mismatch
        return {
            "audio_duration": round(self.audio_duration, 3),
            "video_duration": round(self.video_duration, 3),
            "difference_seconds": round(m.difference_seconds, 3),
            "percentage_diff": round(m.percentage_diff, 1),
            "requires_adjustment": m.requires_adjustment,
            "strategy": m.strategy,
            "adjustment_factor": round(m.adjustment_factor, 4),
            "audio_filters": self.get_audio_filters(),
            "video_filters": self.get_video_filters(),
            "padding_seconds": round(self.get_padding_seconds(), 3),
        }
