# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import importlib
from pathlib import Path
from typing import Any, Dict, Optional, Type

from packaging.version import Version
from pydantic import BaseModel

from ....utils import logging
from ...utils.official_models import official_models
from ..utils import check_backend, model_name_to_module_name

# Names of the factory functions that a per-model module is expected to expose.
# Each getter is looked up by name on the imported module and called with the
# backend name (see `get_model_components` and `get_default_config`).
NETWORK_CLASS_GETTER_KEY = "get_network_class"
PROCESSOR_CLASS_GETTER_KEY = "get_processor_class"
CONFIG_GETTER_KEY = "get_config"
# NOTE(review): not referenced anywhere in this part of the file — confirm
# whether it is still used elsewhere.
CHAT_TEMPLATE_PATH_GETTER_KEY = "get_chat_template_path"
# File name looked up inside the model directory when no packaged chat template
# exists for the model.
DEFAULT_CHAT_TEMPLATE_FILENAME = "chat_template.jinja"

# Registry of known models. A model name missing from this mapping is rejected
# as unknown. The optional `min_<backend>_version` entries give the lowest
# version of the backend library that is accepted for the model.
ALL_MODEL_INFO = {
    "PaddleOCR-VL-0.9B": {
        "min_vllm_version": "0.11.1",
        "min_sglang_version": "0.5.7",
    },
    "PaddleOCR-VL-1.5-0.9B": {
        "min_vllm_version": "0.11.1",
        "min_sglang_version": "0.5.7",
    },
}


def _check_model_name_and_backend(model_name, backend):
    """Validate a model name / backend pair.

    Args:
        model_name: Name of the model; must be a key of `ALL_MODEL_INFO`.
        backend: Name of the inference backend; validated by `check_backend`.

    Raises:
        ValueError: If `model_name` is not registered in `ALL_MODEL_INFO`.
    """
    is_registered = model_name in ALL_MODEL_INFO
    if not is_registered:
        raise ValueError(f"Unknown model: {model_name}")

    # The model name is checked first, then the backend.
    check_backend(backend)


def is_integrated_model_available(model_name, backend):
    """Tell whether the installed backend library is recent enough for a model.

    Args:
        model_name: Name of a model registered in `ALL_MODEL_INFO`.
        backend: Name of the inference backend.

    Returns:
        `True` when the model declares a `min_<backend>_version` and the
        installed backend library is at least that version; `False` when the
        model declares no minimum version for the backend or the installed
        version is older.

    Raises:
        ValueError: If `model_name` is not registered.
    """
    _check_model_name_and_backend(model_name, backend)

    requirements = ALL_MODEL_INFO[model_name]
    version_key = f"min_{backend}_version"

    # Without a declared minimum version the model is reported as unavailable.
    if version_key not in requirements:
        return False

    # The backend libraries are imported lazily so that only the requested
    # one has to be installed.
    if backend == "vllm":
        import vllm

        installed_version = vllm.__version__
    elif backend == "sglang":
        import sglang

        installed_version = sglang.__version__
    else:
        # Any other backend is treated as having the lowest possible version.
        installed_version = "0.0.0"

    return Version(installed_version) >= Version(requirements[version_key])


def get_model_dir(model_name, backend):
    """Return the directory of the official model as a string.

    Args:
        model_name: Name of a model registered in `ALL_MODEL_INFO`.
        backend: Name of the inference backend (validated, and used in the
            error message).

    Returns:
        The value of `official_models[model_name]` converted with `str()`.

    Raises:
        ValueError: If `model_name` is not registered.
        RuntimeError: If looking the model up in `official_models` fails for
            any reason; the original exception is chained as the cause.
    """
    _check_model_name_and_backend(model_name, backend)

    try:
        resolved_dir = official_models[model_name]
    except Exception as exc:
        message = f"Could not prepare the official model for the {repr(model_name)} model with the {repr(backend)} backend."
        raise RuntimeError(message) from exc
    else:
        return str(resolved_dir)


def get_model_components(model_name, backend):
    """Load the network class (and, for SGLang, the processor class) of a model.

    The model's module is imported relative to this package and its getter
    functions are called with `backend`.

    Args:
        model_name: Name of a model registered in `ALL_MODEL_INFO`.
        backend: Name of the inference backend.

    Returns:
        A `(network_class, processor_class)` tuple. `processor_class` is
        `None` unless `backend` is `"sglang"`.

    Raises:
        ValueError: If the model is not registered or its module cannot be
            found.
        RuntimeError: If the module lacks a required getter function.
    """
    _check_model_name_and_backend(model_name, backend)

    module_name = model_name_to_module_name(model_name)
    try:
        module = importlib.import_module(f".{module_name}", package=__package__)
    except ModuleNotFoundError as exc:
        raise ValueError(f"Unknown model: {model_name}") from exc

    def _resolve(getter_name):
        # Look the getter up by name on the model module and call it with
        # the backend.
        if hasattr(module, getter_name):
            return getattr(module, getter_name)(backend)
        raise RuntimeError(f"`{module}` does not have `{getter_name}`")

    network_class = _resolve(NETWORK_CLASS_GETTER_KEY)
    # Only the SGLang backend is asked for a processor class.
    processor_class = (
        _resolve(PROCESSOR_CLASS_GETTER_KEY) if backend == "sglang" else None
    )

    return network_class, processor_class


def get_default_config(model_name, backend):
    """Return the default configuration of a model for a backend.

    The configuration comes from the `get_config` function of the model's
    module in the sibling `configs` package. A model without such a module
    has an empty default configuration.

    Args:
        model_name: Name of a model registered in `ALL_MODEL_INFO`.
        backend: Name of the inference backend; passed to the config getter.

    Returns:
        Whatever the config getter returns, or `{}` when the model has no
        config module.

    Raises:
        ValueError: If `model_name` is not registered.
        ModuleNotFoundError: If the config module exists but one of its own
            imports cannot be found.
        RuntimeError: If the config module does not define `get_config`.
    """
    _check_model_name_and_backend(model_name, backend)

    mod_name = model_name_to_module_name(model_name)
    relative_name = f"..configs.{mod_name}"

    try:
        config_module = importlib.import_module(relative_name, package=__package__)
    except ModuleNotFoundError as e:
        # `from ... import` on purpose: `import importlib.util` would bind
        # `importlib` as a local name and break the call above.
        from importlib.util import resolve_name

        target_name = resolve_name(relative_name, __package__)
        missing_name = e.name
        # Only a missing config module (or a missing parent package of it)
        # means "no default config". Any other missing module is a broken
        # import inside an existing config module and must not be hidden.
        if (
            missing_name is not None
            and missing_name != target_name
            and not target_name.startswith(f"{missing_name}.")
        ):
            raise
        logging.debug("No default configs were found for the model '%s'", model_name)
        return {}

    if not hasattr(config_module, CONFIG_GETTER_KEY):
        raise RuntimeError(f"`{config_module}` does not have `{CONFIG_GETTER_KEY}`")
    config_getter = getattr(config_module, CONFIG_GETTER_KEY)
    return config_getter(backend)


@contextlib.contextmanager
def get_chat_template_path(model_name, backend, model_dir):
    """Yield the path of the chat template to use for a model, if any.

    Lookup order:

    1. `<model_name>.jinja` packaged in `paddlex.inference.genai.chat_templates`.
    2. `chat_template.jinja` inside `model_dir`.

    This is a context manager; the yielded path should be used inside the
    `with` block, because the packaged resource is only guaranteed to be
    accessible while the underlying `importlib.resources` context is open.

    Args:
        model_name: Name of a model registered in `ALL_MODEL_INFO`.
        backend: Name of the inference backend (validated, and used in the
            debug message).
        model_dir: Directory of the model, searched for the default template.

    Yields:
        A `pathlib.Path` to the chat template, or `None` if none was found.

    Raises:
        ValueError: If `model_name` is not registered.
    """
    # Import the submodule explicitly: a bare `import importlib` does not
    # guarantee that `importlib.resources` is available as an attribute.
    # `from ... import` keeps the module-level `importlib` name untouched.
    from importlib import resources

    _check_model_name_and_backend(model_name, backend)

    with resources.path(
        "paddlex.inference.genai.chat_templates", f"{model_name}.jinja"
    ) as chat_template_path:
        if chat_template_path.exists():
            yield chat_template_path
        else:
            default_chat_template_path = Path(model_dir, DEFAULT_CHAT_TEMPLATE_FILENAME)
            # `is_file()` is already False for a path that does not exist.
            if default_chat_template_path.is_file():
                # TODO: Support symbolic links
                yield default_chat_template_path
            else:
                logging.debug(
                    "No chat template was found for the model '%s' with the backend '%s'",
                    model_name,
                    backend,
                )
                yield None
