# from fastapi import FastAPI
# from pydantic import BaseModel
# import requests
# import uuid
# import os
# import logging
# from typing import List
# from concurrent.futures import ThreadPoolExecutor, as_completed
# from PIL import Image
# import imagehash
# import json
# from io import BytesIO
# from datetime import datetime
# from fastapi.responses import FileResponse
# from fastapi import FastAPI, Request
# from fastapi.responses import RedirectResponse, JSONResponse


# # Import your actual prediction logic
# from filter_images import predict_all

# # Setup logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)

# # Constants
# TEMP_DIR = "temp_images"
# HASH_FILE = "image_hashes.json"
# os.makedirs(TEMP_DIR, exist_ok=True)

# # Load or initialize hash dictionary
# if os.path.exists(HASH_FILE):
#     with open(HASH_FILE, "r") as f:
#         hash_dict = json.load(f)
# else:
#     hash_dict = {}

# # Convert string hashes to imagehash objects
# hash_dict = {imagehash.hex_to_hash(k): v for k, v in hash_dict.items()}
# hash_func = imagehash.phash

# # FastAPI app
# app = FastAPI(title="Image Filter API")

# # Request model
# class ImageBatchRequest(BaseModel):
#     image_urls: List[str]
#     activity: str

# # Response model
# class FilterResponse(BaseModel):
#     image_url: str
#     accepted: bool
#     reasons: List[str]

# # POST /filter endpoint with duplicate detection
# @app.post("/filter", response_model=List[FilterResponse], summary="Filter multiple images by URL")
# def filter_images(request: ImageBatchRequest):
#     results = []
#     activity = request.activity
#     print(activity)
 
#     def process_url(image_url: str) -> FilterResponse:
#         try:
#             logger.info(f"Received image URL: {image_url}")
#             response = requests.get(image_url, timeout=10)
#             if response.status_code != 200:
#                 return FilterResponse(
#                     image_url=image_url,
#                     accepted=False,
#                     reasons=["Failed to download image from URL."]
#                 )

#             ext = image_url.split(".")[-1].split("?")[0].lower()
#             if ext not in ["jpg", "jpeg", "png"]:
#                 return FilterResponse(
#                     image_url=image_url,
#                     accepted=False,
#                     reasons=["Unsupported image format."]
#                 )

#             filename = f"{uuid.uuid4()}.{ext}"
#             filepath = os.path.join(TEMP_DIR, filename)
#             with open(filepath, "wb") as f:
#                 f.write(response.content)

#             # Check for duplicate using perceptual hash
#             try:
#                 img = Image.open(BytesIO(response.content))
#                 img_hash = hash_func(img)

#                 # if img_hash in hash_dict:
#                 #     original_url = hash_dict[img_hash]
#                 #     os.remove(filepath)
#                 #     return FilterResponse(
#                 #         image_url=image_url,
#                 #         accepted=False,
#                 #         reasons=[f"Duplicate image detected. Matches previously processed image: {original_url}"]
#                 #     )

#                 # Save new hash
#                 hash_dict[img_hash] = image_url
#                 with open(HASH_FILE, "w") as f:
#                     json.dump({str(k): v for k, v in hash_dict.items()}, f, indent=2)

#             except Exception as e:
#                 logger.warning(f"Hashing failed for {image_url}: {e}")

#             # Run prediction
#             result = predict_all(filepath,activity)
#             os.remove(filepath)

#             return FilterResponse(
#                 image_url=image_url,
#                 accepted=result.get("accepted", False),
#                 reasons=result.get("reasons", ["Image rejected."])
#             )

#         except Exception as e:
#             logger.error(f"Error processing image {image_url}: {str(e)}")
#             return FilterResponse(
#                 image_url=image_url,
#                 accepted=False,
#                 reasons=[f"Error: {str(e)}"]
#             )

#     # Run predictions in parallel
#     with ThreadPoolExecutor(max_workers=4) as executor:
#         futures = [executor.submit(process_url, url) for url in request.image_urls]
#         for future in as_completed(futures):
#             results.append(future.result())

#     return results

# # POST /filter_dev endpoint (bypass logic)
# @app.post("/filter_dev", response_model=List[FilterResponse], summary="Filter multiple images by URL (bypass logic)")
# def filter_images_bypass(request: ImageBatchRequest):
#     results = []
#     for image_url in request.image_urls:
#         logger.info(f"Bypass filter received image URL: {image_url}")
#         results.append(FilterResponse(
#             image_url=image_url,
#             accepted=True,
#             reasons=[]
#         ))
#     return results



# # Public Spaces URL of your file
# SPACES_URL = "https://grq-img-store.sgp1.digitaloceanspaces.com/prod/classifiers_zip/latest.zip"

# @app.get("/download")
# def download_zip(request: Request):
#     # Forward If-Modified-Since header to Spaces
#     headers = {}
#     if_modified_since = request.headers.get("If-Modified-Since")
#     if if_modified_since:
#         headers["If-Modified-Since"] = if_modified_since

#     # Just redirect client to Spaces URL
#     return RedirectResponse(SPACES_URL, headers=headers)




# Standard library
import json
import logging
import os
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from io import BytesIO
from typing import List

# Third-party
import imagehash
import requests
from fastapi import FastAPI, Request
from fastapi.responses import RedirectResponse, JSONResponse
from PIL import Image
from pydantic import BaseModel

# Import your actual prediction logic
from filter_images import predict_all

# Logging: module-level logger, INFO and above to stderr.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
TEMP_DIR = "temp_images"        # scratch dir for downloaded images (files removed after prediction)
HASH_FILE = "image_hashes.json" # persisted map of perceptual-hash hex string -> source URL
os.makedirs(TEMP_DIR, exist_ok=True)  # idempotent; safe on restart

# Tolerant JSON loader
def safe_load_json(path):
    """Load JSON from *path*, tolerating missing, empty, or malformed files.

    Returns the parsed value when the file is well-formed JSON. If parsing
    the whole file fails, each line is parsed independently as a salvage
    pass: dict lines are merged into one dict, and non-dict lines are stored
    under their line index (as a string key). A missing or empty file
    yields an empty dict.
    """
    try:
        f = open(path, "r")
    except FileNotFoundError:
        return {}
    with f:
        text = f.read().strip()
    if not text:
        return {}

    try:
        return json.loads(text)  # the common, well-formed case
    except json.JSONDecodeError:
        pass

    # Salvage pass: keep whatever individual lines still parse.
    recovered = {}
    for idx, raw_line in enumerate(text.splitlines()):
        try:
            parsed = json.loads(raw_line)
        except Exception:
            continue
        if isinstance(parsed, dict):
            recovered.update(parsed)
        else:
            recovered[str(idx)] = parsed
    return recovered

# Load the persisted hash index. safe_load_json already returns {} for a
# missing file, so the previous os.path.exists() pre-check was redundant.
_raw_hashes = safe_load_json(HASH_FILE)

# Convert stored hex strings back into ImageHash objects. The tolerant
# loader may have salvaged entries whose keys are not valid hex hashes;
# skip those instead of crashing at startup.
hash_dict = {}
for _hex_key, _src_url in _raw_hashes.items():
    try:
        hash_dict[imagehash.hex_to_hash(_hex_key)] = _src_url
    except (ValueError, TypeError):
        logger.warning("Skipping malformed hash entry in %s: %r", HASH_FILE, _hex_key)

# Perceptual hash function used for duplicate detection.
hash_func = imagehash.phash

# FastAPI app
app = FastAPI(title="Image Filter API")

# Request model
class ImageBatchRequest(BaseModel):
    """Request body for /filter and /filter_dev."""
    image_urls: List[str]  # URLs of images to download and evaluate
    activity: str          # activity label forwarded to predict_all

# Response model
class FilterResponse(BaseModel):
    """Per-URL verdict returned by the filter endpoints."""
    image_url: str      # the URL this verdict applies to
    accepted: bool      # True when the image passed all checks
    reasons: List[str]  # rejection/error explanations; empty when accepted

# Guards the shared hash_dict and HASH_FILE against concurrent mutation by
# the worker threads below (and by overlapping requests). The original
# check-then-insert plus unguarded file rewrite raced under the thread pool.
_hash_lock = threading.Lock()

# POST /filter endpoint with duplicate detection
@app.post("/filter", response_model=List[FilterResponse], summary="Filter multiple images by URL")
def filter_images(request: ImageBatchRequest):
    """Download each URL, reject duplicates/bad formats, and classify the rest.

    URLs are processed in parallel (4 worker threads). Every URL yields
    exactly one FilterResponse; per-URL failures are reported in `reasons`
    rather than failing the whole batch.
    """
    activity = request.activity
    # logger instead of the stray print() the original had
    logger.info("Filtering %d image(s) for activity: %s", len(request.image_urls), activity)

    def process_url(image_url: str) -> FilterResponse:
        # Process one URL; never raises — all errors become a rejection.
        filepath = None
        try:
            logger.info("Received image URL: %s", image_url)
            response = requests.get(image_url, timeout=10)
            if response.status_code != 200:
                return FilterResponse(
                    image_url=image_url,
                    accepted=False,
                    reasons=["Failed to download image from URL."]
                )

            # Extension taken from the URL path, query string stripped.
            ext = image_url.split(".")[-1].split("?")[0].lower()
            if ext not in ("jpg", "jpeg", "png"):
                return FilterResponse(
                    image_url=image_url,
                    accepted=False,
                    reasons=["Unsupported image format."]
                )

            # Persist to a uniquely-named temp file for the classifier.
            filepath = os.path.join(TEMP_DIR, f"{uuid.uuid4()}.{ext}")
            with open(filepath, "wb") as f:
                f.write(response.content)

            # Duplicate detection via perceptual hash. Hashing failures are
            # non-fatal: the image still goes through classification.
            try:
                img = Image.open(BytesIO(response.content))
                img_hash = hash_func(img)

                with _hash_lock:
                    if img_hash in hash_dict:
                        original_url = hash_dict[img_hash]
                        return FilterResponse(
                            image_url=image_url,
                            accepted=False,
                            reasons=[f"Duplicate image detected. Matches previously processed image: {original_url}"]
                        )
                    # Record the new hash and persist the index atomically
                    # with respect to other workers.
                    hash_dict[img_hash] = image_url
                    with open(HASH_FILE, "w") as f:
                        json.dump({str(k): v for k, v in hash_dict.items()}, f, indent=2)
            except Exception as e:
                logger.warning("Hashing failed for %s: %s", image_url, e)

            # Run prediction
            result = predict_all(filepath, activity)
            return FilterResponse(
                image_url=image_url,
                accepted=result.get("accepted", False),
                reasons=result.get("reasons", ["Image rejected."])
            )

        except Exception as e:
            logger.error("Error processing image %s: %s", image_url, e)
            return FilterResponse(
                image_url=image_url,
                accepted=False,
                reasons=[f"Error: {str(e)}"]
            )
        finally:
            # Always remove the temp file — the original leaked it whenever
            # predict_all raised, since cleanup only ran on the success path.
            if filepath and os.path.exists(filepath):
                try:
                    os.remove(filepath)
                except OSError:
                    pass

    # Run predictions in parallel
    results = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(process_url, url) for url in request.image_urls]
        for future in as_completed(futures):
            results.append(future.result())

    return results

# POST /filter_dev endpoint (bypass logic)
@app.post("/filter_dev", response_model=List[FilterResponse], summary="Filter multiple images by URL (bypass logic)")
def filter_images_bypass(request: ImageBatchRequest):
    """Accept every submitted URL unconditionally (development shortcut)."""
    responses = []
    for url in request.image_urls:
        logger.info(f"Bypass filter received image URL: {url}")
        responses.append(
            FilterResponse(image_url=url, accepted=True, reasons=[])
        )
    return responses

# Public Spaces URL of your file
SPACES_URL = "https://grq-img-store.sgp1.digitaloceanspaces.com/prod/classifiers_zip/latest.zip"

@app.get("/download")
def download_zip(request: Request):
    """Redirect the caller to the Spaces URL of the classifiers zip.

    With a redirect, the *client* re-issues the request directly to Spaces
    and supplies its own conditional headers (e.g. If-Modified-Since) on
    that follow-up request. The previous code copied the incoming
    If-Modified-Since onto the redirect *response*, which is meaningless —
    it is a request header, and clients ignore it on a response — so that
    dead code has been removed.
    """
    return RedirectResponse(SPACES_URL)

