# Copyright (C) 2021-2026, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

# Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/hub.py

import json
import logging
import subprocess
import tempfile
import textwrap
from pathlib import Path
from typing import Any

import torch
from huggingface_hub import (
    HfApi,
    get_token,
    hf_hub_download,
    login,
)

from doctr import models

# Names exported by this module (note that the private `_save_model_and_config_for_hf_hub` helper is listed too)
__all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]


# Architecture names grouped by task; `push_to_hf_hub` rejects any architecture not listed for its task
AVAILABLE_ARCHS = {
    "classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS,
    "detection": models.detection.zoo.ARCHS,
    "recognition": models.recognition.zoo.ARCHS,
}


def login_to_hub() -> None:  # pragma: no cover
    """Login to huggingface hub and verify that git-lfs is available

    If `get_token` returns a token, it is reused for the login, otherwise `login()` is called without one.

    Raises:
        OSError: if the `git lfs version` command cannot be executed or exits with a non-zero status
    """
    access_token = get_token()
    if access_token is not None:
        logging.info("Huggingface Hub token found and valid")
        login(token=access_token)
    else:
        login()
    # check if git lfs is installed
    try:
        # check=True is required: when git is present but the `lfs` extension is not, git merely exits
        # with a non-zero status; FileNotFoundError only covers a missing git executable
        subprocess.run(["git", "lfs", "version"], check=True)
    except (FileNotFoundError, subprocess.CalledProcessError) as err:
        raise OSError(
            "Looks like you do not have git-lfs installed, please install. "
            "You can install from https://git-lfs.github.com/. "
            "Then run `git lfs install` (you only have to do this once)."
        ) from err


def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task: str) -> None:
    """Save model and config to disk for pushing to huggingface hub

    Two files are written to `save_dir`: `pytorch_model.bin` (the model's state dict) and `config.json`
    (the model's `cfg` extended with the `arch` and `task` entries).

    Note: `model.cfg` is updated in place, so it also holds the `arch` and `task` keys afterwards.

    Args:
        model: PyTorch model to be saved
        save_dir: directory to save model and config (created if it does not exist yet)
        arch: architecture name
        task: task name
    """
    save_directory = Path(save_dir)
    # make sure the target exists so that callers are not required to create it beforehand
    save_directory.mkdir(parents=True, exist_ok=True)

    weights_path = save_directory / "pytorch_model.bin"
    torch.save(model.state_dict(), weights_path)

    config_path = save_directory / "config.json"

    # add model configuration
    model_config = model.cfg
    model_config["arch"] = arch
    model_config["task"] = task

    # ensure_ascii=False writes non-ASCII characters verbatim, so the encoding is pinned to UTF-8
    # instead of being left to the platform default
    with config_path.open("w", encoding="utf-8") as f:
        json.dump(model_config, f, indent=2, ensure_ascii=False)


def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs: Any) -> None:  # pragma: no cover
    """Save model and its configuration on HF hub

    >>> from doctr.models import login_to_hub, push_to_hf_hub
    >>> from doctr.models.recognition import crnn_mobilenet_v3_small
    >>> login_to_hub()
    >>> model = crnn_mobilenet_v3_small(pretrained=True)
    >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')

    Args:
        model: PyTorch model to be saved
        model_name: name of the model which is also the repository name
        task: task name, one of classification, detection, recognition
        **kwargs: keyword arguments for push_to_hf_hub: either `arch` (architecture name) or `run_config`
            (object exposing an `arch` attribute, whose attributes are dumped as JSON into the README).
            When `run_config` is given, its `arch` takes precedence over the `arch` keyword.

    Raises:
        ValueError: if neither `run_config` nor `arch` is given, if `task` is unknown, or if the
            architecture is not available for `task`
    """
    run_config = kwargs.get("run_config", None)
    arch = kwargs.get("arch", None)

    if run_config is None and arch is None:
        raise ValueError("run_config or arch must be specified")
    if task not in ("classification", "detection", "recognition"):
        raise ValueError("task must be one of classification, detection, recognition")

    # default readme
    # NOTE: the leading `---` block is the YAML metadata header of the model card; without the
    # delimiters `language: en` would be treated as plain body text
    readme = textwrap.dedent(
        f"""
    ---
    language: en
    ---

    <p align="center">
    <img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
    </p>

    **Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch**

    ## Task: {task}

    https://github.com/mindee/doctr

    ### Example usage:

    ```python
    >>> from doctr.io import DocumentFile
    >>> from doctr.models import ocr_predictor, from_hub

    >>> img = DocumentFile.from_images(['<image_path>'])
    >>> # Load your model from the hub
    >>> model = from_hub('mindee/my-model')

    >>> # Pass it to the predictor
    >>> # If your model is a recognition model:
    >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
    >>>                           reco_arch=model,
    >>>                           pretrained=True)

    >>> # If your model is a detection model:
    >>> predictor = ocr_predictor(det_arch=model,
    >>>                           reco_arch='crnn_mobilenet_v3_small',
    >>>                           pretrained=True)

    >>> # Get your predictions
    >>> res = predictor(img)
    ```
    """
    )

    # add run configuration to readme if available
    if run_config is not None:
        # the run configuration is the source of truth for the architecture name
        arch = run_config.arch
        run_config_json = json.dumps(vars(run_config), indent=2, ensure_ascii=False)
        readme += f"### Run Configuration\n\n{run_config_json}"

    if arch not in AVAILABLE_ARCHS[task]:
        raise ValueError(
            f"Architecture: {arch} for task: {task} not found.\nAvailable architectures: {AVAILABLE_ARCHS}"
        )

    commit_message = f"Add {model_name} model"

    # Create repository (exist_ok=False: pushing to an already existing repository is an error)
    api = HfApi()
    api.create_repo(model_name, token=get_token(), exist_ok=False)

    # Save model files to a temporary directory
    with tempfile.TemporaryDirectory() as tmp_dir:
        _save_model_and_config_for_hf_hub(model, tmp_dir, arch=arch, task=task)
        readme_path = Path(tmp_dir) / "README.md"
        # the run configuration is dumped with ensure_ascii=False, so pin the encoding explicitly
        readme_path.write_text(readme, encoding="utf-8")

        # Upload all files to the hub
        api.upload_folder(
            folder_path=tmp_dir,
            repo_id=model_name,
            commit_message=commit_message,
            token=get_token(),
        )


def from_hub(repo_id: str, **kwargs: Any) -> Any:
    """Instantiate & load a pretrained model from HF hub.

    >>> from doctr.models import from_hub
    >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")

    Args:
        repo_id: HuggingFace model hub repo
        kwargs: kwargs of `hf_hub_download` (forwarded to both the config and the weights download)

    Returns:
        Model loaded with the checkpoint

    Raises:
        ValueError: if the task stored in the repo config is not classification, detection or recognition
    """
    # Get the config
    with open(hf_hub_download(repo_id, filename="config.json", **kwargs), "rb") as f:
        cfg = json.load(f)

    # `arch` and `task` only select the constructor: strip them before cfg becomes the model config
    arch = cfg.pop("arch")
    task = cfg.pop("task")

    if task == "classification":
        model = models.classification.__dict__[arch](
            pretrained=False, classes=cfg["classes"], num_classes=cfg["num_classes"]
        )
    elif task == "detection":
        model = models.detection.__dict__[arch](pretrained=False)
    elif task == "recognition":
        model = models.recognition.__dict__[arch](pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"])
    else:
        # fail with an explicit message instead of an unbound `model` further down
        raise ValueError(f"Unknown task: {task}. task must be one of classification, detection, recognition")

    # update model cfg
    model.cfg = cfg
    # load the weights
    weights = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
    model.from_pretrained(weights)

    return model
