import cv2
import numpy as np
from PIL import Image, ImageEnhance, ImageDraw, ImageFont
import pillow_heif
import os
import shutil

from datetime import datetime
from PIL import Image, ImageEnhance, ImageDraw, ImageFont, ExifTags

# Register the HEIF/HEIC opener with Pillow once, at import time, so that
# Image.open() (used by ImageUtils.load_image) can read .heic/.heif files.
pillow_heif.register_heif_opener()

class ImageUtils:
    """Helpers for preparing photos for printing.

    Covers loading (including HEIC and EXIF orientation), face-aware cropping
    or white padding to a target aspect ratio, light enhancement, date
    stamping and JPEG export. Face detection uses the optional
    ``face_recognition`` package when it can be imported, otherwise OpenCV's
    frontal-face Haar cascade.
    """

    # EXIF tag ids used for the capture-date lookup.
    _EXIF_IFD_POINTER = 0x8769  # pointer from IFD0 to the Exif sub-IFD
    _TAG_DATETIME_ORIGINAL = 36867  # DateTimeOriginal, "YYYY:MM:DD HH:MM:SS"

    def __init__(self):
        self.use_face_recognition_lib = False
        try:
            import face_recognition
            self.face_recognition = face_recognition
            self.use_face_recognition_lib = True
        except ImportError:
            print("Warning: 'face_recognition' library not found. Falling back to OpenCV Haar Cascade.")
            # Only created on the fallback path; _detect_faces checks the flag first.
            self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

    def load_image(self, file_path: str) -> Image.Image:
        """Load an image as an upright, fully decoded RGB ``Image``.

        HEIC/HEIF files are readable because ``pillow_heif`` registers its
        opener at import time. The EXIF Orientation tag is applied (all eight
        values, including the mirrored ones) so the pixels are upright. The
        source file is closed before returning.

        Args:
            file_path: Path of the image file to read.

        Returns:
            A detached RGB image; the file's metadata dict (``info``) is kept.

        Raises:
            ValueError: If the file cannot be opened or decoded. The original
                exception is chained as ``__cause__``.
        """
        from PIL import ImageOps  # local import: only this method needs it

        try:
            with Image.open(file_path) as src:
                try:
                    # Returns a new, fully loaded image whether or not it rotates.
                    img = ImageOps.exif_transpose(src)
                except Exception:
                    # Best effort: malformed EXIF (or a Pillow without
                    # exif_transpose) -> keep the pixels as stored.
                    img = src.copy()
            if img.mode != 'RGB':
                img = img.convert('RGB')
            return img
        except Exception as e:
            raise ValueError(f"Failed to load image: {e}") from e

    def save_image(self, img: Image.Image, output_path: str, quality: int = 95):
        """Save *img* as a JPEG at *output_path* and return *output_path*.

        JPEG cannot store alpha or palette data, so images in other modes are
        converted to RGB first (the caller's image is not modified).
        """
        if img.mode not in ("RGB", "L", "CMYK"):
            img = img.convert("RGB")
        img.save(output_path, "JPEG", quality=quality)
        return output_path

    def auto_enhance(self, img: Image.Image, brightness_factor: float = 1.05, sharpness_factor: float = 1.15) -> Image.Image:
        """Return a copy of *img* with brightness, then sharpness, scaled.

        A factor of 1.0 leaves that property unchanged. Note that the returned
        image does not carry over *img*'s metadata.
        """
        img = ImageEnhance.Brightness(img).enhance(brightness_factor)
        img = ImageEnhance.Sharpness(img).enhance(sharpness_factor)
        return img

    def _detect_faces(self, img: Image.Image, detect_width: int = 800) -> list:
        """Return detected faces as ``(x, y, w, h)`` int tuples in *img*'s pixel coordinates.

        For speed, detection runs on a copy downscaled to *detect_width*
        pixels wide when *img* is wider than that; the boxes are scaled back.
        """
        width, height = img.size
        scale = 1.0
        if width > detect_width:
            scale = detect_width / width
            img = img.resize((detect_width, max(1, int(height * scale))))

        pixels = np.array(img)
        if self.use_face_recognition_lib:
            # face_recognition takes RGB and returns (top, right, bottom, left).
            faces = [
                (left, top, right - left, bottom - top)
                for top, right, bottom, left in self.face_recognition.face_locations(pixels)
            ]
        else:
            # The Haar cascade works on grayscale and returns (x, y, w, h) rows.
            gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
            faces = [
                tuple(int(v) for v in box)
                for box in self.face_cascade.detectMultiScale(gray, 1.1, 4)
            ]

        if scale != 1.0:
            faces = [
                (int(x / scale), int(y / scale), int(w / scale), int(h / scale))
                for x, y, w, h in faces
            ]
        return faces

    def get_smart_crop_box(self, img: Image.Image, target_ratio: float, detect_width: int = 800) -> tuple:
        """Compute a crop box with aspect ratio *target_ratio*, centred on faces.

        The box keeps the full extent of one dimension and trims the other.
        It is centred on the mean centre of all detected faces, or on the
        image centre when none are found, then clamped to stay inside the
        image.

        Args:
            img: Source image.
            target_ratio: Desired width / height (e.g. 4/6 for a portrait 4x6 print).
            detect_width: Width that face detection downscales to (speed only).

        Returns:
            ``(left, top, right, bottom)`` ints, suitable for ``Image.crop``.

        Raises:
            ValueError: If *target_ratio* is not positive.
        """
        if target_ratio <= 0:
            raise ValueError(f"target_ratio must be positive, got {target_ratio}")

        original_width, original_height = img.size
        faces = self._detect_faces(img, detect_width)

        # Centre of interest: mean face centre, else the image centre.
        if len(faces) > 0:
            center_x = sum(x + w / 2 for x, _y, w, _h in faces) / len(faces)
            center_y = sum(y + h / 2 for _x, y, _w, h in faces) / len(faces)
        else:
            center_x = original_width / 2
            center_y = original_height / 2

        # Example: 4000x3000 (1.33) to 4x6 (0.66) -> image is wider -> trim width.
        # Example: 3000x4000 (0.75) to 4x6 (0.66) -> still wider -> trim width.
        current_ratio = original_width / original_height
        if current_ratio > target_ratio:
            # Wider than target: keep full height, trim width.
            new_width, new_height = int(original_height * target_ratio), original_height
        else:
            # Taller/narrower than target: keep full width, trim height.
            new_width, new_height = original_width, int(original_width / target_ratio)

        # Clamp so the box never leaves the image (new_* never exceed the originals).
        left = min(max(center_x - new_width / 2, 0), original_width - new_width)
        top = min(max(center_y - new_height / 2, 0), original_height - new_height)

        return (int(left), int(top), int(left + new_width), int(top + new_height))

    def _add_white_padding(self, img: Image.Image, target_ratio: float) -> Image.Image:
        """Pad *img* with white bars to reach aspect ratio *target_ratio* ("image full" mode).

        Nothing is cropped. The original is pasted centred on a white RGB
        canvas. If the ratio already matches within 0.01, *img* itself is
        returned.

        Raises:
            ValueError: If *target_ratio* is not positive.
        """
        if target_ratio <= 0:
            raise ValueError(f"target_ratio must be positive, got {target_ratio}")

        original_width, original_height = img.size
        current_ratio = original_width / original_height

        if abs(current_ratio - target_ratio) < 0.01:
            return img

        if current_ratio > target_ratio:
            # Wider than target: pad top/bottom.  h_new = w / target_ratio
            new_width = original_width
            new_height = int(original_width / target_ratio)
        else:
            # Narrower than target: pad left/right.  w_new = h * target_ratio
            new_height = original_height
            new_width = int(original_height * target_ratio)

        new_img = Image.new("RGB", (new_width, new_height), (255, 255, 255))
        paste_x = (new_width - original_width) // 2
        paste_y = (new_height - original_height) // 2
        new_img.paste(img, (paste_x, paste_y))
        return new_img

    def _exif_date_text(self, img: Image.Image):
        """Return *img*'s EXIF DateTimeOriginal formatted as ``"YYYY. MM. DD."``, or None.

        The tag is looked up in the Exif sub-IFD first and then in IFD0.
        Missing or unparsable metadata yields None instead of raising.
        """
        try:
            exif = img.getexif()
        except Exception:
            return None  # best effort: no readable metadata

        raw = None
        try:
            raw = exif.get_ifd(self._EXIF_IFD_POINTER).get(self._TAG_DATETIME_ORIGINAL)
        except Exception:
            # Older Pillow lacks Exif.get_ifd, or the sub-IFD is malformed.
            raw = None
        if not raw:
            raw = exif.get(self._TAG_DATETIME_ORIGINAL)
        if not raw:
            return None

        if isinstance(raw, bytes):
            raw = raw.decode("ascii", "ignore")
        try:
            # EXIF format: "YYYY:MM:DD HH:MM:SS", sometimes NUL/space padded.
            dt = datetime.strptime(raw.strip("\x00 "), "%Y:%m:%d %H:%M:%S")
        except (AttributeError, TypeError, ValueError):
            return None
        return dt.strftime("%Y. %m. %d.")

    def _read_capture_date(self, file_path: str):
        """Return the capture date of the file at *file_path* (see _exif_date_text), or None.

        Only the file header is needed, because Image.open is lazy.
        """
        try:
            with Image.open(file_path) as src:
                return self._exif_date_text(src)
        except (OSError, ValueError):
            return None

    def _draw_date_text(self, img: Image.Image, date_text: str = None) -> Image.Image:
        """Draw *date_text* near the bottom-right corner of *img* (in place) and return *img*.

        If *date_text* is empty, *img*'s own EXIF DateTimeOriginal is used.
        If no date is available, *img* is returned untouched. The text is
        white with a 2px black drop shadow, sized at about 3% of the image
        height, with a 5% margin.
        """
        if not date_text:
            date_text = self._exif_date_text(img)
        if not date_text:
            return img  # could not determine a date

        draw = ImageDraw.Draw(img)
        w, h = img.size

        # ~3% of height; at least 1 so tiny images still get a valid font size.
        font_size = max(1, int(h * 0.03))
        try:
            # On Windows "malgun.ttf" is an alternative with Korean glyphs.
            font = ImageFont.truetype("arial.ttf", font_size)
        except OSError:
            font = ImageFont.load_default()

        try:
            left, top, right, bottom = draw.textbbox((0, 0), date_text, font=font)
            text_w, text_h = right - left, bottom - top
        except AttributeError:
            # Pillow < 8 has no textbbox (and Pillow >= 10 no longer has textsize).
            text_w, text_h = draw.textsize(date_text, font=font)

        x = w - text_w - int(w * 0.05)
        y = h - text_h - int(h * 0.05)

        shadow_color = (0, 0, 0)
        text_color = (255, 255, 255)  # RGB has no alpha, so plain white
        draw.text((x + 2, y + 2), date_text, font=font, fill=shadow_color)
        draw.text((x, y), date_text, font=font, fill=text_color)
        return img

    def process_and_save(self, input_path: str, output_path: str, fill_mode: str = "paper_full",
                         target_ratio_w_h: float = 4/6, do_enhance: bool = True,
                         print_date: bool = False, date_text: str = None):
        """Load, reshape, optionally enhance and date-stamp an image, then save it as JPEG.

        Args:
            input_path: Source image (any Pillow-readable format, including HEIC).
            output_path: Destination JPEG path.
            fill_mode: ``"paper_full"`` face-aware crops to the ratio;
                ``"image_full"`` pads with white to the ratio. Any other value
                leaves the geometry unchanged.
            target_ratio_w_h: Target width / height.
            do_enhance: Apply ``auto_enhance``.
            print_date: Stamp a date in the bottom-right corner.
            date_text: Text to stamp. When empty, the source file's EXIF
                capture date is used (nothing is drawn if there is none).

        Returns:
            *output_path*.

        Raises:
            ValueError: If the input image cannot be loaded.
        """
        img = self.load_image(input_path)

        if print_date and not date_text:
            # Read the date from the source file now: cropping, padding and
            # enhancement produce new images that lose the EXIF metadata.
            date_text = self._read_capture_date(input_path)

        if fill_mode == "paper_full":
            img = img.crop(self.get_smart_crop_box(img, target_ratio_w_h))
        elif fill_mode == "image_full":
            img = self._add_white_padding(img, target_ratio_w_h)

        if do_enhance:
            img = self.auto_enhance(img)

        if print_date:
            img = self._draw_date_text(img, date_text)

        self.save_image(img, output_path)
        return output_path

# Shared module-level instance, created at import time. Constructing it tries
# to import face_recognition; if that fails, it prints a warning and loads
# the OpenCV Haar cascade instead. (ImageUtils can still be instantiated
# separately, so this is a convenience instance, not an enforced singleton.)
image_utils: ImageUtils = ImageUtils()
