# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import copy
import io
import os
import warnings
from pathlib import Path
from typing import List, Optional

import numpy as np

from ....modules.doc_vlm.model_list import MODELS
from ....utils import logging
from ....utils.deps import require_genai_client_plugin
from ....utils.device import TemporaryDeviceChanger
from ...common.batch_sampler import DocVLMBatchSampler
from ...utils.misc import is_bfloat16_available
from ...utils.model_paths import get_model_paths
from ..base import BasePredictor
from .result import DocVLMResult


class DocVLMPredictor(BasePredictor):
    """Predictor for document vision-language (DocVLM) models.

    The supported model families are listed in ``model_group``. Inference runs
    either on a locally loaded dynamic-graph Paddle model, or (when a local
    model is not used) through ``self._genai_client``, which sends
    chat-completion requests whose request/response shapes follow the OpenAI
    chat-completions format.
    """

    entities = MODELS
    # Maps a model family ("series") name to the concrete model names in it.
    model_group = {
        "PP-DocBee": {"PP-DocBee-2B", "PP-DocBee-7B"},
        "PP-DocBee2": {"PP-DocBee2-3B"},
        "PP-Chart2Table": {"PP-Chart2Table"},
        "PaddleOCR-VL": {"PaddleOCR-VL-0.9B", "PaddleOCR-VL-1.5-0.9B"},
    }

    # Generation budget used for PaddleOCR-VL when the caller does not pass
    # `max_new_tokens` (applies to both local and server inference).
    _PADDLEOCR_VL_DEFAULT_MAX_NEW_TOKENS = 8192

    def __init__(self, *args, **kwargs):
        """Initializes DocVLMPredictor.

        Args:
            *args: Arbitrary positional arguments passed to the superclass.
            **kwargs: Arbitrary keyword arguments passed to the superclass.
                ``device`` and ``use_hpip`` are also read here when a local
                model is used.

        Raises:
            RuntimeError: If a static-graph local model is requested, or if the
                batch size is ``-1`` and no default is known for the model.
        """
        super().__init__(*args, **kwargs)

        # A batch size of -1 means "pick a model-specific default".
        if self.batch_sampler.batch_size == -1:
            self.batch_sampler.batch_size = self._determine_batch_size()

        if self._use_local_model:
            if self._use_static_model:
                raise RuntimeError("Static graph models are not supported")
            self.device = kwargs.get("device", None)
            if is_bfloat16_available(self.device):
                self.dtype = "bfloat16"
            else:
                self.dtype = "float32"

            self.infer, self.processor = self._build(**kwargs)

            if (
                self.model_name in self.model_group["PaddleOCR-VL"]
                and self.batch_sampler.batch_size > 1
            ):
                logging.warning(
                    f"Currently, the {repr(self.model_name)} local model only supports batch size of 1. The batch size will be updated to 1."
                )
                self.batch_sampler.batch_size = 1

    def _build_batch_sampler(self):
        """Builds and returns a DocVLMBatchSampler instance.

        Returns:
            DocVLMBatchSampler: An instance of DocVLMBatchSampler.
        """
        return DocVLMBatchSampler(self.model_name)

    def _get_result_class(self):
        """Returns the result class, DocVLMResult.

        Returns:
            type: The DocVLMResult class.
        """
        return DocVLMResult

    def _get_model_series(self):
        """Return the ``model_group`` key whose member set contains ``self.model_name``.

        Returns:
            str or None: The series name, or ``None`` if the model belongs to no group.
        """
        for series, names in self.model_group.items():
            if self.model_name in names:
                return series
        return None

    def _build(self, **kwargs):
        """Build the model and its corresponding processor.

        Args:
            **kwargs: Predictor options. Only ``use_hpip`` is read here. No
                DocVLM series supports it, so a truthy value only triggers a
                warning.

        Returns:
            tuple: ``(model, processor)``. ``model`` is a dynamic-graph Paddle
                model loaded from ``self.model_dir`` with ``self.dtype``, and
                ``processor`` is the matching pre/post-processor.

        Raises:
            NotImplementedError: If ``self.model_name`` is not a supported model.
        """
        from .modeling import (
            PaddleOCRVLForConditionalGeneration,
            PPChart2TableInference,
            PPDocBee2Inference,
            PPDocBeeInference,
        )

        # build processor
        processor = self.build_processor()

        series = self._get_model_series()
        if series is None:
            raise NotImplementedError(f"Model {self.model_name} is not supported.")
        if kwargs.get("use_hpip", False):
            warnings.warn(
                f"The {series} series does not support `use_hpip=True` for now."
            )

        # build model while the configured device is active
        with TemporaryDeviceChanger(self.device):
            if series == "PP-DocBee":
                model = PPDocBeeInference.from_pretrained(
                    self.model_dir, dtype=self.dtype
                )
            elif series == "PP-Chart2Table":
                model_path = get_model_paths(self.model_dir)
                chart_kwargs = {
                    "dtype": self.dtype,
                    "pad_token_id": processor.tokenizer.eos_token_id,
                }
                if "safetensors" in model_path:
                    # safetensors checkpoints are loaded with HF conversion enabled.
                    chart_kwargs.update(use_safetensors=True, convert_from_hf=True)
                model = PPChart2TableInference.from_pretrained(
                    self.model_dir, **chart_kwargs
                )
            elif series == "PP-DocBee2":
                model = PPDocBee2Inference.from_pretrained(
                    self.model_dir,
                    dtype=self.dtype,
                )
            elif series == "PaddleOCR-VL":
                model = PaddleOCRVLForConditionalGeneration.from_pretrained(
                    self.model_dir,
                    dtype=self.dtype,
                    convert_from_hf=True,
                )
            else:
                raise NotImplementedError(
                    f"Model {self.model_name} is not supported."
                )

        return model, processor

    def _determine_batch_size(self):
        """Choose a default batch size when the configured one is ``-1``.

        Returns:
            int: 1 for a local PaddleOCR-VL model, or 4096 when requests go
                through the generative-AI client.

        Raises:
            RuntimeError: If no default batch size is known for the model.
        """
        if self._model_name in self.model_group["PaddleOCR-VL"]:
            batch_size = 1 if self._use_local_model else 4096
            logging.debug(
                f"The batch size of {self._model_name} is determined to be {batch_size}."
            )
            return batch_size
        raise RuntimeError(f"Could not determine batch size for {self._model_name}")

    def process(
        self,
        data: List[dict],
        max_new_tokens: Optional[int] = None,
        skip_special_tokens: Optional[bool] = None,
        repetition_penalty: Optional[float] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ):
        """Run preprocessing, inference, and postprocessing on a batch.

        Args:
            data (List[dict]): A batch of samples. Each must be a dict, e.g.
                ``{"image": /path/to/image, "query": some question}``.
            max_new_tokens (Optional[int]): Maximum number of generated tokens.
                Defaults to 8192 for PaddleOCR-VL. Otherwise, when ``None``,
                it is not passed at all.
            skip_special_tokens (Optional[bool]): Forwarded to decoding
                (local postprocess or server ``extra_body``).
            repetition_penalty (Optional[float]): Server inference only. The
                local model warns and ignores it.
            temperature (Optional[float]): Server inference only. The local
                model warns and ignores it.
            top_p (Optional[float]): Server inference only. The local model
                warns and ignores it.
            min_pixels (Optional[int]): Used by the local PaddleOCR-VL
                processor and the ``vllm-server`` backend. Ignored with a
                warning elsewhere.
            max_pixels (Optional[int]): Same support as ``min_pixels``.
            use_cache (Optional[bool]): Forwarded to the local ``generate``
                call. Not used for server inference.
            **kwargs: Ignored.

        Returns:
            dict: Column-oriented results. Each input key maps to a list of
                per-sample values, and ``"result"`` maps to the list of
                predictions.
        """
        # TODO: Sampling settings
        # FIXME: When `skip_special_tokens` is `True`, the results from different backends may differ.

        assert all(isinstance(i, dict) for i in data)

        if self._use_local_model:
            # Keep the original samples; `data` is rebound to processor outputs.
            src_data = copy.copy(data)
            # preprocess
            if self.model_name in self.model_group["PaddleOCR-VL"]:
                data = self.processor.preprocess(
                    data, min_pixels=min_pixels, max_pixels=max_pixels
                )
            else:
                data = self.processor.preprocess(data)
                if min_pixels is not None:
                    warnings.warn(
                        f"`min_pixels` is currently not supported by the {repr(self.model_name)} model and will be ignored."
                    )
                if max_pixels is not None:
                    warnings.warn(
                        f"`max_pixels` is currently not supported by the {repr(self.model_name)} model and will be ignored."
                    )

            data = self._switch_inputs_to_device(data)

            # do infer
            generate_kwargs = {}
            if max_new_tokens is not None:
                generate_kwargs["max_new_tokens"] = max_new_tokens
            elif self.model_name in self.model_group["PaddleOCR-VL"]:
                generate_kwargs["max_new_tokens"] = (
                    self._PADDLEOCR_VL_DEFAULT_MAX_NEW_TOKENS
                )
            if repetition_penalty is not None:
                warnings.warn(
                    "`repetition_penalty` is currently not supported by the local model and will be ignored."
                )
            if temperature is not None:
                warnings.warn(
                    "`temperature` is currently not supported by the local model and will be ignored."
                )
            if top_p is not None:
                warnings.warn(
                    "`top_p` is currently not supported by the local model and will be ignored."
                )
            if use_cache is not None:
                generate_kwargs["use_cache"] = use_cache
            with TemporaryDeviceChanger(self.device):
                preds = self.infer.generate(
                    data,
                    **generate_kwargs,
                )

            # postprocess
            postprocess_kwargs = {}
            if skip_special_tokens is not None:
                postprocess_kwargs["skip_special_tokens"] = skip_special_tokens
            preds = self.processor.postprocess(preds, **postprocess_kwargs)
        else:
            require_genai_client_plugin()

            src_data = data

            preds = self._genai_client_process(
                data,
                max_new_tokens=max_new_tokens,
                skip_special_tokens=skip_special_tokens,
                repetition_penalty=repetition_penalty,
                temperature=temperature,
                top_p=top_p,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )

        result_dict = self._format_result_dict(preds, src_data)
        return result_dict

    def build_processor(self, **kwargs):
        """Create the pre/post-processor matching ``self.model_name``.

        Args:
            **kwargs: Unused.

        Returns:
            The processor instance for the model family, built from the
            tokenizer and image-processor files in ``self.model_dir``.

        Raises:
            NotImplementedError: If ``self.model_name`` is not a supported model.
        """
        from ..common.tokenizer import (
            LlamaTokenizer,
            MIXQwen2_5_Tokenizer,
            MIXQwen2Tokenizer,
            QWenTokenizer,
        )
        from ..common.tokenizer.tokenizer_utils import ChatTemplate
        from .processors import (
            GOTImageProcessor,
            PaddleOCRVLProcessor,
            PPChart2TableProcessor,
            PPDocBee2Processor,
            PPDocBeeProcessor,
            Qwen2_5_VLImageProcessor,
            Qwen2VLImageProcessor,
            SiglipImageProcessor,
        )

        if self.model_name in self.model_group["PP-DocBee"]:
            image_processor = Qwen2VLImageProcessor()
            tokenizer = MIXQwen2Tokenizer.from_pretrained(self.model_dir)
            return PPDocBeeProcessor(
                image_processor=image_processor, tokenizer=tokenizer
            )
        elif self.model_name in self.model_group["PP-Chart2Table"]:
            image_processor = GOTImageProcessor(1024)
            tokenizer = QWenTokenizer.from_pretrained(self.model_dir)
            return PPChart2TableProcessor(
                image_processor=image_processor, tokenizer=tokenizer, dtype=self.dtype
            )
        elif self.model_name in self.model_group["PP-DocBee2"]:
            image_processor = Qwen2_5_VLImageProcessor()
            tokenizer = MIXQwen2_5_Tokenizer.from_pretrained(self.model_dir)
            return PPDocBee2Processor(
                image_processor=image_processor, tokenizer=tokenizer
            )
        elif self.model_name in self.model_group["PaddleOCR-VL"]:
            image_processor = SiglipImageProcessor.from_pretrained(self.model_dir)
            vocab_file = str(Path(self.model_dir, "tokenizer.model"))
            tokenizer = LlamaTokenizer.from_pretrained(
                self.model_dir, vocab_file=vocab_file
            )
            # HACK: the chat template is read from a separate Jinja file in the
            # model directory and compiled onto the tokenizer directly.
            chat_template_file = Path(self.model_dir, "chat_template.jinja")
            tokenizer.chat_template = ChatTemplate._compile_jinja_template(
                chat_template_file.read_text(encoding="utf-8")
            )
            return PaddleOCRVLProcessor(
                image_processor=image_processor,
                tokenizer=tokenizer,
            )
        else:
            raise NotImplementedError(f"Model {self.model_name} is not supported.")

    def _format_result_dict(self, model_preds, src_data):
        """Merge predictions with their source samples into a column-oriented dict.

        Args:
            model_preds: A prediction or a list of predictions, one per sample.
            src_data: A sample dict or a list of sample dicts.

        Returns:
            dict: Every key seen in any sample maps to a list with one entry
                per sample. ``"input_path"`` is included for samples whose
                ``"image"`` is a string. ``"result"`` maps to the list of
                predictions. A sample lacking a key contributes ``None`` for it.

        Raises:
            ValueError: If the number of predictions and samples differ.
        """
        if not isinstance(model_preds, list):
            model_preds = [model_preds]
        if not isinstance(src_data, list):
            src_data = [src_data]

        input_info = []
        for sample in src_data:
            # Copy so that adding "input_path" does not mutate the caller's dict.
            sample = dict(sample)
            image = sample.get("image", None)
            if isinstance(image, str):
                sample["input_path"] = image
            input_info.append(sample)
        if len(model_preds) != len(input_info):
            raise ValueError(
                f"Model predicts {len(model_preds)} results while src data has {len(input_info)} samples."
            )

        # Union of keys in first-seen order. A batch may mix path inputs (which
        # gain "input_path") with array inputs (which do not), so keys taken
        # from only the first sample would be incomplete.
        keys = list(dict.fromkeys(k for sample in input_info for k in sample))
        rst_format_dict = {k: [sample.get(k) for sample in input_info] for k in keys}
        rst_format_dict["result"] = list(model_preds)

        return rst_format_dict

    def _infer_dynamic_forward_device(self, device):
        """Infer the physical forward device for a dynamic-graph model.

        Logical GPU indices in ``device`` are mapped through
        ``CUDA_VISIBLE_DEVICES``. When that variable is unset, all GPUs
        reported by GPUtil are treated as visible.

        Args:
            device (str or None): Device spec parsed by ``parse_device``
                (presumably of the form ``"gpu:0,1"``; TODO confirm), or a
                string containing ``"cpu"``.

        Returns:
            str or None: ``None`` if ``device`` is ``None``, ``"cpu"`` for CPU
                specs, otherwise ``"<device_type>:<id>[,<id>...]"`` using the
                visible-device ids.

        Raises:
            ValueError: If a ``CUDA_VISIBLE_DEVICES`` entry is not an integer,
                or if a requested index exceeds the number of visible devices.
        """
        import GPUtil

        from ....utils.device import parse_device

        if device is None:
            return None
        if "cpu" in device.lower():
            return "cpu"
        device_type, device_ids = parse_device(device)
        # NOTE(review): assumes parse_device yields a non-empty list of ints here.

        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if cuda_visible_devices is None:
            env_gpu_num = len(GPUtil.getGPUs())
            cuda_visible_devices = ",".join([str(i) for i in range(env_gpu_num)])
        # Tolerate whitespace (e.g. "0, 1") and an empty variable (no devices).
        env_device_ids = [
            env_device_id.strip()
            for env_device_id in cuda_visible_devices.split(",")
            if env_device_id.strip()
        ]
        for env_device_id in env_device_ids:
            if not env_device_id.isdigit():
                raise ValueError(
                    f"CUDA_VISIBLE_DEVICES ID must be an integer. Invalid device ID: {env_device_id}"
                )

        if max(device_ids) >= len(env_device_ids):
            raise ValueError(
                f"Required gpu ids {device_ids} even larger than the number of visible devices {cuda_visible_devices}."
            )

        rst_global_gpu_ids = [env_device_ids[idx] for idx in device_ids]
        return device_type + ":" + ",".join(rst_global_gpu_ids)

    def _switch_inputs_to_device(self, input_dict):
        """Copy every ``paddle.Tensor`` value in ``input_dict`` onto ``self.device``.

        Args:
            input_dict (dict): Processor outputs. Non-tensor values are passed
                through unchanged.

        Returns:
            dict: A new dict with tensors placed on ``self.device``, or
                ``input_dict`` itself when ``self.device`` is ``None``.
        """
        import paddle

        if self.device is None:
            return input_dict
        rst_dict = {
            k: (
                paddle.to_tensor(input_dict[k], place=self.device)
                if isinstance(input_dict[k], paddle.Tensor)
                else input_dict[k]
            )
            for k in input_dict
        }
        return rst_dict

    @staticmethod
    def _encode_image_as_url(image, image_format):
        """Convert an input image into a URL usable in a chat-completion request.

        Args:
            image: An ``http://``/``https://`` URL (returned unchanged), a local
                file path, or a ``numpy.ndarray`` in BGR channel order.
            image_format (str): PIL format name used to encode path/array images.

        Returns:
            str: The URL, or a ``data:image/<format>;base64,...`` URI.

        Raises:
            TypeError: If ``image`` is neither a ``str`` nor a ``numpy.ndarray``.
        """
        if isinstance(image, str):
            if image.startswith(("http://", "https://")):
                return image
            from PIL import Image

            with Image.open(image) as img:
                # convert() loads pixel data into a new image, so it stays
                # usable after the file handle is closed.
                pil_image = img.convert("RGB")
        elif isinstance(image, np.ndarray):
            import cv2
            from PIL import Image

            pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        else:
            raise TypeError(f"Not supported image type: {type(image)}")

        with io.BytesIO() as buf:
            pil_image.save(buf, format=image_format)
            encoded = base64.b64encode(buf.getvalue()).decode("ascii")
        return f"data:image/{image_format.lower()};base64,{encoded}"

    def _build_genai_request_kwargs(
        self,
        max_new_tokens,
        skip_special_tokens,
        repetition_penalty,
        temperature,
        top_p,
        min_pixels,
        max_pixels,
    ):
        """Build the backend-specific options shared by every request in a batch.

        Returns:
            dict: Keyword arguments for ``create_chat_completion`` (sampling
                options, the max-token limit, and an ``extra_body`` dict).

        Raises:
            ValueError: If ``skip_special_tokens`` is given for a backend that
                does not accept it.
        """
        backend = self._genai_client.backend

        if backend == "fastdeploy-server":
            kwargs = {
                "temperature": 1 if temperature is None else temperature,
                "top_p": 0 if top_p is None else top_p,
            }
        else:
            kwargs = {
                "temperature": 0 if temperature is None else temperature,
            }
            if top_p is not None:
                kwargs["top_p"] = top_p

        # Backends disagree on the name of the token-limit field.
        if backend in ("mlx-vlm-server", "llama-cpp-server"):
            max_tokens_name = "max_tokens"
        else:
            max_tokens_name = "max_completion_tokens"

        if max_new_tokens is not None:
            kwargs[max_tokens_name] = max_new_tokens
        elif self.model_name in self.model_group["PaddleOCR-VL"]:
            kwargs[max_tokens_name] = self._PADDLEOCR_VL_DEFAULT_MAX_NEW_TOKENS

        extra_body = {}
        if skip_special_tokens is not None:
            if backend not in (
                "fastdeploy-server",
                "vllm-server",
                "sglang-server",
                "mlx-vlm-server",
                "llama-cpp-server",
            ):
                raise ValueError(
                    f"`skip_special_tokens` is not supported by the {repr(backend)} backend."
                )
            extra_body["skip_special_tokens"] = skip_special_tokens

        if repetition_penalty is not None:
            extra_body["repetition_penalty"] = repetition_penalty

        # Only vLLM receives pixel bounds (via `mm_processor_kwargs`).
        for option_name, option_value in (
            ("min_pixels", min_pixels),
            ("max_pixels", max_pixels),
        ):
            if option_value is None:
                continue
            if backend == "vllm-server":
                extra_body.setdefault("mm_processor_kwargs", {})[
                    option_name
                ] = option_value
            else:
                warnings.warn(
                    f"{repr(backend)} does not support `{option_name}`."
                )

        kwargs["extra_body"] = extra_body
        return kwargs

    def _genai_client_process(
        self,
        data,
        max_new_tokens,
        skip_special_tokens,
        repetition_penalty,
        temperature,
        top_p,
        min_pixels,
        max_pixels,
    ):
        """Send one chat-completion request per sample and collect the replies.

        Requests are submitted as futures first and then awaited in order. If
        anything fails, pending futures are cancelled and the error is re-raised.

        Args:
            data (List[dict]): Samples with ``"image"`` and ``"query"`` keys.
            max_new_tokens: See ``process``.
            skip_special_tokens: See ``process``.
            repetition_penalty: See ``process``.
            temperature: See ``process``.
            top_p: See ``process``.
            min_pixels: See ``process``.
            max_pixels: See ``process``.

        Returns:
            list: The message content of the first choice of each response,
                in input order.
        """
        if not data:
            return []

        if self._genai_client.backend == "llama-cpp-server":
            image_format = "PNG"
        else:
            image_format = "JPEG"

        # The options do not depend on the sample, so build them once.
        request_kwargs = self._build_genai_request_kwargs(
            max_new_tokens=max_new_tokens,
            skip_special_tokens=skip_special_tokens,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )

        futures = []
        try:
            for item in data:
                image_url = self._encode_image_as_url(item["image"], image_format)
                future = self._genai_client.create_chat_completion(
                    [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image_url", "image_url": {"url": image_url}},
                                {"type": "text", "text": item["query"]},
                            ],
                        }
                    ],
                    return_future=True,
                    timeout=600,
                    # Give every request its own copy of the nested option dicts.
                    **copy.deepcopy(request_kwargs),
                )
                futures.append(future)

            return [fut.result().choices[0].message.content for fut in futures]
        except Exception:
            # Cancel all pending futures to avoid wasting resources
            for future in futures:
                if not future.done():
                    future.cancel()
            raise
