# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import queue
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from PIL import Image

from ....utils import logging
from ....utils.deps import pipeline_requires_extra
from ...common.batch_sampler import ImageBatchSampler
from ...common.reader import ReadImage
from ...models import HPIConfig, PaddlePredictorOption
from ...utils.benchmark import benchmark
from .._parallel import AutoParallelImageSimpleInferencePipeline
from ..base import BasePipeline
from ..components import CropByBoxes
from ..layout_parsing.merge_table import merge_tables_across_pages
from ..layout_parsing.title_level import assign_levels_to_parsing_res
from ..layout_parsing.utils import construct_img_path, gather_imgs
from .result import BaseResult, PaddleOCRVLBlock, PaddleOCRVLResult
from .uilts import (
    convert_otsl_to_html,
    crop_margin,
    filter_overlap_boxes,
    merge_blocks,
    post_process_for_spotting,
    pre_process_for_spotting,
    tokenize_figure_of_table,
    truncate_repetitive_content,
    untokenize_figure_of_table,
)

# Layout labels treated as pictures. Unless `use_ocr_for_image_block` is set,
# blocks carrying one of these labels are not sent to the VL recognition model
# (see `get_layout_parsing_results`); their crops can instead be attached to the
# parsing results as images.
IMAGE_LABELS = ["image", "header_image", "footer_image"]


@benchmark.time_methods
class _PaddleOCRVLPipeline(BasePipeline):
    """_PaddleOCRVLPipeline Pipeline"""

    def __init__(
        self,
        config: Dict,
        *,
        device: Optional[str] = None,
        engine: Optional[str] = None,
        engine_config: Optional[Dict[str, Any]] = None,
        pp_option: Optional[PaddlePredictorOption] = None,
        use_hpip: bool = False,
        hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
        initial_predictor: bool = True,
        **kwargs,
    ) -> None:
        """Set up the PaddleOCR-VL pipeline from its configuration.

        Args:
            config (Dict): Pipeline configuration. The keys read here are
                `use_doc_preprocessor`, `use_layout_detection`,
                `use_chart_recognition`, `use_seal_recognition`,
                `format_block_content`, `use_ocr_for_image_block`,
                `batch_size`, `use_queues`, `merge_layout_blocks`,
                `markdown_ignore_labels`, `layout_prep_cpu_workers`, plus the
                `SubPipelines.DocPreprocessor`, `SubModules.LayoutDetection`
                and `SubModules.VLRecognition` sections.
            device (Optional[str], optional): Device used for prediction.
                Defaults to `None`.
            engine (Optional[str], optional): Inference engine. Defaults to `None`.
            engine_config (Optional[Dict[str, Any]], optional): Engine-specific
                configuration. Defaults to `None`.
            pp_option (Optional[PaddlePredictorOption], optional): Paddle
                predictor options. Defaults to `None`.
            use_hpip (bool, optional): Whether to use HPIP. Defaults to `False`.
            hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
                HPIP configuration. Defaults to `None`.
            initial_predictor (bool, optional): When falsy, only the base class
                is initialised and none of the attributes below are created.
                Defaults to `True`.

        Raises:
            AssertionError: If layout detection is enabled and the configured
                `model_name` is not PP-DocLayoutV2 or PP-DocLayoutV3.
            ValueError: If `layout_prep_cpu_workers` cannot be converted to an
                integer.
        """
        super().__init__(
            device=device,
            engine=engine,
            engine_config=engine_config,
            pp_option=pp_option,
            use_hpip=use_hpip,
            hpi_config=hpi_config,
            **kwargs,
        )

        # Nothing below applies when the caller does not want predictors built.
        if not initial_predictor:
            return

        # --- optional document preprocessing sub-pipeline ---
        self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
        if self.use_doc_preprocessor:
            sub_pipelines = config.get("SubPipelines", {})
            missing_pipeline_cfg = {
                "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
            }
            self.doc_preprocessor_pipeline = self.create_pipeline(
                sub_pipelines.get("DocPreprocessor", missing_pipeline_cfg)
            )

        # --- optional layout detection model ---
        self.use_layout_detection = config.get("use_layout_detection", True)
        if self.use_layout_detection:
            layout_det_config = config.get("SubModules", {}).get(
                "LayoutDetection",
                {"model_config_error": "config error for layout_det_model!"},
            )
            supported_layout_models = ("PP-DocLayoutV2", "PP-DocLayoutV3")
            model_name = layout_det_config.get("model_name", None)
            assert (
                model_name is not None and model_name in supported_layout_models
            ), "model_name must be PP-DocLayoutV2 or PP-DocLayoutV3"

            # Forward only the tuning options that are actually configured.
            layout_kwargs = {}
            for option in (
                "threshold",
                "layout_nms",
                "layout_unclip_ratio",
                "layout_merge_bboxes_mode",
            ):
                option_value = layout_det_config.get(option, None)
                if option_value is not None:
                    layout_kwargs[option] = option_value
            self.layout_det_model = self.create_model(
                layout_det_config, **layout_kwargs
            )

        self.use_chart_recognition = config.get("use_chart_recognition", False)
        self.use_seal_recognition = config.get("use_seal_recognition", False)

        # --- VL recognition model (always created) ---
        vl_rec_config = config.get("SubModules", {}).get(
            "VLRecognition",
            {"model_config_error": "config error for vl_rec_model!"},
        )
        self.vl_rec_model = self.create_model(vl_rec_config)
        self.format_block_content = config.get("format_block_content", False)
        self.use_ocr_for_image_block = config.get("use_ocr_for_image_block", False)

        # --- input helpers ---
        batch_size = config.get("batch_size", 1)
        self.batch_sampler = ImageBatchSampler(batch_size=batch_size)
        self.img_reader = ReadImage(format="BGR")
        self.crop_by_boxes = CropByBoxes()

        self.use_queues = config.get("use_queues", False)
        self.merge_layout_blocks = config.get("merge_layout_blocks", True)
        default_ignore_labels = [
            "number",
            "footnote",
            "header",
            "header_image",
            "footer",
            "footer_image",
            "aside_text",
        ]
        self.markdown_ignore_labels = config.get(
            "markdown_ignore_labels", default_ignore_labels
        )

        # Negative worker counts are clamped to 0.
        raw_workers = config.get("layout_prep_cpu_workers", 0)
        try:
            worker_count = int(raw_workers)
        except (TypeError, ValueError) as exc:
            raise ValueError(
                "`layout_prep_cpu_workers` must be a non-negative integer "
                "(0 disables parallel CPU layout preparation)."
            ) from exc
        self.layout_prep_cpu_workers = max(0, worker_count)

    def close(self):
        """Close the VL recognition model, if this instance has one.

        `vl_rec_model` is only assigned when the constructor ran with
        `initial_predictor` enabled, hence the presence check.
        """
        if not hasattr(self, "vl_rec_model"):
            return
        self.vl_rec_model.close()

    def get_model_settings(
        self,
        use_doc_orientation_classify: Union[bool, None],
        use_doc_unwarping: Union[bool, None],
        use_layout_detection: Union[bool, None],
        use_chart_recognition: Union[bool, None],
        use_seal_recognition: Union[bool, None],
        use_ocr_for_image_block: Union[bool, None],
        format_block_content: Union[bool, None],
        merge_layout_blocks: Union[bool, None],
        markdown_ignore_labels: Optional[List[str]] = None,
    ) -> dict:
        """
        Resolve per-call settings, falling back to the pipeline defaults.

        Args:
            use_doc_orientation_classify (Union[bool, None]): Request document orientation classification.
            use_doc_unwarping (Union[bool, None]): Request document unwarping.
                When both of these are None the pipeline's `use_doc_preprocessor`
                is used; otherwise the preprocessor is enabled only if at least
                one of them is exactly True.
            use_layout_detection (Union[bool, None]): Override, or None for the pipeline default.
            use_chart_recognition (Union[bool, None]): Override, or None for the pipeline default.
            use_seal_recognition (Union[bool, None]): Override, or None for the pipeline default.
            use_ocr_for_image_block (Union[bool, None]): Override, or None for the pipeline default.
            format_block_content (Union[bool, None]): Override, or None for the pipeline default.
            merge_layout_blocks (Union[bool, None]): Override, or None for the pipeline default.
            markdown_ignore_labels (Optional[List[str]]): Override, or None for the pipeline default.

        Returns:
            dict: The resolved settings, keyed by `use_doc_preprocessor`,
                `use_layout_detection`, `use_chart_recognition`,
                `use_seal_recognition`, `use_ocr_for_image_block`,
                `format_block_content`, `merge_layout_blocks` and
                `markdown_ignore_labels`.

        """

        def _or_default(value, attr_name):
            # The instance attribute is only read when no override was given.
            return getattr(self, attr_name) if value is None else value

        preprocess_flags = (use_doc_orientation_classify, use_doc_unwarping)
        if all(flag is None for flag in preprocess_flags):
            use_doc_preprocessor = self.use_doc_preprocessor
        else:
            use_doc_preprocessor = any(flag is True for flag in preprocess_flags)

        return {
            "use_doc_preprocessor": use_doc_preprocessor,
            "use_layout_detection": _or_default(
                use_layout_detection, "use_layout_detection"
            ),
            "use_chart_recognition": _or_default(
                use_chart_recognition, "use_chart_recognition"
            ),
            "use_seal_recognition": _or_default(
                use_seal_recognition, "use_seal_recognition"
            ),
            "use_ocr_for_image_block": _or_default(
                use_ocr_for_image_block, "use_ocr_for_image_block"
            ),
            "format_block_content": _or_default(
                format_block_content, "format_block_content"
            ),
            "merge_layout_blocks": _or_default(
                merge_layout_blocks, "merge_layout_blocks"
            ),
            "markdown_ignore_labels": _or_default(
                markdown_ignore_labels, "markdown_ignore_labels"
            ),
        }

    def check_model_settings_valid(self, input_params: dict) -> bool:
        """
        Verify that the requested settings are backed by initialized models.

        Only the document preprocessor is checked here.

        Args:
            input_params (Dict): Resolved settings; `use_doc_preprocessor` is read.

        Returns:
            bool: False (after logging an error) when the doc preprocessor is
                requested but this pipeline was configured without it, True otherwise.
        """
        preprocessor_requested = input_params["use_doc_preprocessor"]
        if not preprocessor_requested or self.use_doc_preprocessor:
            return True

        logging.error(
            "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized.",
        )
        return False

    @benchmark.timeit_with_options(name="paddleocr_vl_filter_overlap_boxes")
    def _paddleocr_vl_filter_overlap_boxes(self, layout_det_res, layout_shape_mode):
        """Benchmark-decorated pass-through to `filter_overlap_boxes`."""
        filtered_res = filter_overlap_boxes(layout_det_res, layout_shape_mode)
        return filtered_res

    @benchmark.timeit_with_options(name="paddleocr_vl_crop_layout_regions")
    def _paddleocr_vl_crop_layout_regions(self, image, boxes, layout_shape_mode):
        """Benchmark-decorated pass-through to this pipeline's `crop_by_boxes`."""
        cropped_blocks = self.crop_by_boxes(image, boxes, layout_shape_mode)
        return cropped_blocks

    @benchmark.timeit_with_options(name="paddleocr_vl_merge_adjacent_blocks")
    def _paddleocr_vl_merge_adjacent_blocks(self, blocks_for_img, layout_prep_cfg):
        """Merge a page's blocks via `merge_blocks` when merging is enabled.

        Blocks labelled as images (per `layout_prep_cfg["image_labels"]`) or as
        "table" are passed as non-mergeable. With merging disabled the input
        list is returned as is.
        """
        if layout_prep_cfg["merge_layout_blocks"]:
            protected_labels = layout_prep_cfg["image_labels"] + ["table"]
            return merge_blocks(blocks_for_img, non_merge_labels=protected_labels)
        return blocks_for_img

    def _paddleocr_vl_collect_page_vlm_entries_core(
        self,
        page_idx,
        blocks_for_img,
        imgs_in_doc_for_img,
        layout_prep_cfg,
    ):
        page_vlm_entries = []
        page_has_spotting = False
        page_drop_figures = set()

        for j, block in enumerate(blocks_for_img):
            block_img = block["img"]
            block_label = block["label"]
            if (
                block_label not in layout_prep_cfg["image_labels"]
                and block_img is not None
            ):
                figure_token_map = {}
                text_prompt = "OCR:"
                blk_min_pixels = layout_prep_cfg["ocr_min_pixels"]
                blk_max_pixels = layout_prep_cfg["ocr_max_pixels"]
                drop_figures = []
                if block_label == "table":
                    text_prompt = "Table Recognition:"
                    block_img, figure_token_map, drop_figures = (
                        tokenize_figure_of_table(
                            block_img, block["box"], imgs_in_doc_for_img
                        )
                    )
                    blk_min_pixels = layout_prep_cfg["table_min_pixels"]
                    blk_max_pixels = layout_prep_cfg["table_max_pixels"]
                elif (
                    block_label == "chart" and layout_prep_cfg["use_chart_recognition"]
                ):
                    text_prompt = "Chart Recognition:"
                    blk_min_pixels = layout_prep_cfg["chart_min_pixels"]
                    blk_max_pixels = layout_prep_cfg["chart_max_pixels"]
                elif "formula" in block_label and block_label != "formula_number":
                    text_prompt = "Formula Recognition:"
                    crop_img = crop_margin(block_img)
                    w, h, _ = crop_img.shape
                    if w > 2 and h > 2:
                        block_img = crop_img
                    blk_min_pixels = layout_prep_cfg["formula_min_pixels"]
                    blk_max_pixels = layout_prep_cfg["formula_max_pixels"]
                elif block_label == "spotting":
                    text_prompt = "Spotting:"
                    page_has_spotting = True
                    blk_min_pixels = 112896
                    blk_max_pixels = 1605632
                    block_img = pre_process_for_spotting(block_img)
                elif block_label == "seal" and layout_prep_cfg["use_seal_recognition"]:
                    text_prompt = "Seal Recognition:"
                    blk_min_pixels = layout_prep_cfg["seal_min_pixels"]
                    blk_max_pixels = layout_prep_cfg["seal_max_pixels"]

                page_vlm_entries.append(
                    (
                        page_idx,
                        j,
                        block_img,
                        text_prompt,
                        (blk_min_pixels, blk_max_pixels),
                        figure_token_map,
                    )
                )
                page_drop_figures.update(drop_figures)

        return page_vlm_entries, page_has_spotting, page_drop_figures

    @benchmark.timeit_with_options(name="paddleocr_vl_collect_block_vlm_inputs")
    def _paddleocr_vl_collect_block_vlm_inputs(
        self,
        page_idx,
        blocks_for_img,
        imgs_in_doc_for_img,
        layout_prep_cfg,
    ):
        """Benchmark-decorated entry point for collecting one page's VLM inputs.

        Forwards all arguments unchanged to
        `_paddleocr_vl_collect_page_vlm_entries_core` and returns its result.
        """
        collected = self._paddleocr_vl_collect_page_vlm_entries_core(
            page_idx, blocks_for_img, imgs_in_doc_for_img, layout_prep_cfg
        )
        return collected

    def _paddleocr_vl_prepare_page_core(self, payload):
        """Prepare one page: filter boxes, crop, optionally merge, build VLM inputs.

        This variant calls the undecorated helpers directly (no nested benchmark
        timers) and is the function mapped over the thread pool.

        Args:
            payload: Tuple `(page_idx, image, layout_det_res, imgs_in_doc_for_img,
                layout_prep_cfg)`.

        Returns:
            tuple: `(page_idx, blocks, vlm_entries, has_spotting, drop_figures)`.
        """
        page_idx, page_image, det_res, page_imgs_in_doc, prep_cfg = payload
        shape_mode = prep_cfg["layout_shape_mode"]

        det_res = filter_overlap_boxes(det_res, shape_mode)
        page_blocks = self.crop_by_boxes(page_image, det_res["boxes"], shape_mode)

        if prep_cfg["merge_layout_blocks"]:
            # Image-like blocks and tables are never merged with neighbours.
            page_blocks = merge_blocks(
                page_blocks,
                non_merge_labels=prep_cfg["image_labels"] + ["table"],
            )

        vlm_entries, has_spotting, drop_figures = (
            self._paddleocr_vl_collect_page_vlm_entries_core(
                page_idx, page_blocks, page_imgs_in_doc, prep_cfg
            )
        )
        return page_idx, page_blocks, vlm_entries, has_spotting, drop_figures

    def _paddleocr_vl_prepare_page_serial_benchmarked(self, payload):
        (
            i,
            image,
            layout_det_res,
            imgs_in_doc_for_img,
            layout_prep_cfg,
        ) = payload
        layout_det_res = self._paddleocr_vl_filter_overlap_boxes(
            layout_det_res, layout_prep_cfg["layout_shape_mode"]
        )
        boxes = layout_det_res["boxes"]
        blocks_for_img = self._paddleocr_vl_crop_layout_regions(
            image, boxes, layout_prep_cfg["layout_shape_mode"]
        )
        blocks_for_img = self._paddleocr_vl_merge_adjacent_blocks(
            blocks_for_img, layout_prep_cfg
        )
        (
            page_vlm_entries,
            page_has_spotting,
            page_drop_figures,
        ) = self._paddleocr_vl_collect_block_vlm_inputs(
            i,
            blocks_for_img,
            imgs_in_doc_for_img,
            layout_prep_cfg,
        )
        return (
            i,
            blocks_for_img,
            page_vlm_entries,
            page_has_spotting,
            page_drop_figures,
        )

    @benchmark.timeit_with_options(name="paddleocr_vl_layout_prep_parallel")
    def _paddleocr_vl_layout_prep_parallel_pages(self, page_payloads, max_workers):
        """Prepare all pages on a thread pool of `max_workers` threads.

        Each payload is handed to `_paddleocr_vl_prepare_page_core`; the
        returned list follows the order of `page_payloads`.
        """
        prepare_page = self._paddleocr_vl_prepare_page_core
        with ThreadPoolExecutor(max_workers=max_workers) as cpu_pool:
            prepared_pages = list(cpu_pool.map(prepare_page, page_payloads))
        return prepared_pages

    @benchmark.timeit_with_options(name="paddleocr_vl_aggregate_vlm_batches")
    def _paddleocr_vl_aggregate_vlm_batches(self, page_results):
        """Combine per-page preparation results into batches keyed by pixel limits.

        Args:
            page_results: Iterable of `(page_idx, blocks, vlm_entries,
                has_spotting, drop_figures)` tuples, one per page.

        Returns:
            tuple: `(blocks, has_spotting, drop_figures_set, batch_dict_by_pixel,
                id2pixel_key_map)`. `blocks` holds the per-page block lists in
                input order. `batch_dict_by_pixel` maps each
                `(min_pixels, max_pixels)` key to a bucket with parallel lists
                "images", "queries", "figure_token_maps", "vlm_block_ids" and a
                read cursor "curr_vlm_block_idx" starting at 0.
                `id2pixel_key_map` maps `(page_idx, block_idx)` to its key.
        """
        blocks = []
        has_spotting = False
        drop_figures_set = set()
        batch_dict_by_pixel = {}
        id2pixel_key_map = {}

        for page_result in page_results:
            _, page_blocks, vlm_entries, spotting_on_page, dropped = page_result
            blocks.append(page_blocks)
            has_spotting = has_spotting or bool(spotting_on_page)
            drop_figures_set.update(dropped)

            for entry in vlm_entries:
                (
                    page_idx,
                    block_idx,
                    block_img,
                    text_prompt,
                    pixel_key,
                    figure_token_map,
                ) = entry
                block_id = (page_idx, block_idx)
                # Blocks sharing the same pixel limits go into the same batch.
                bucket = batch_dict_by_pixel.setdefault(
                    pixel_key,
                    {
                        "images": [],
                        "queries": [],
                        "figure_token_maps": [],
                        "vlm_block_ids": [],
                        "curr_vlm_block_idx": 0,
                    },
                )
                bucket["images"].append(block_img)
                bucket["queries"].append(text_prompt)
                bucket["figure_token_maps"].append(figure_token_map)
                bucket["vlm_block_ids"].append(block_id)
                id2pixel_key_map[block_id] = pixel_key

        return (
            blocks,
            has_spotting,
            drop_figures_set,
            batch_dict_by_pixel,
            id2pixel_key_map,
        )

    @benchmark.timeit_with_options(name="paddleocr_vl_run_vl_recognition_batches")
    def _paddleocr_vl_run_vl_recognition_batches(
        self, batch_dict_by_pixel, has_spotting, vlm_kwargs
    ):
        """Run the VL recognition model once per pixel-limit batch.

        The results of each batch are stored in place under the bucket's
        "vlm_results" key, aligned with its "images" / "queries" lists.

        Args:
            batch_dict_by_pixel: Mapping `(min_pixels, max_pixels)` -> bucket,
                as produced by `_paddleocr_vl_aggregate_vlm_batches`.
            has_spotting: When truthy, special tokens are kept in the output
                (`skip_special_tokens=False`).
            vlm_kwargs: Extra keyword arguments for the model's `predict`, or
                None. It is not modified. A missing or None `max_new_tokens`
                is replaced by 4096.
        """
        # Work on a copy so the default below never leaks into the caller's
        # dict, and apply that default for a None argument as well as for {}.
        gen_kwargs = dict(vlm_kwargs) if vlm_kwargs else {}
        if gen_kwargs.get("max_new_tokens", None) is None:
            gen_kwargs["max_new_tokens"] = 4096

        for pixel_key, bucket in batch_dict_by_pixel.items():
            min_px, max_px = pixel_key
            # Caller-supplied kwargs come last so they override the defaults.
            predict_kwargs = {
                "use_cache": True,
                "min_pixels": min_px,
                "max_pixels": max_px,
                **gen_kwargs,
            }
            requests = [
                {"image": image, "query": query}
                for image, query in zip(bucket["images"], bucket["queries"])
            ]
            bucket["vlm_results"] = list(
                self.vl_rec_model.predict(
                    requests,
                    skip_special_tokens=not has_spotting,
                    **predict_kwargs,
                )
            )
            del requests

    @benchmark.timeit_with_options(name="paddleocr_vl_assemble_parsing_results")
    def _paddleocr_vl_assemble_parsing_results(
        self,
        blocks,
        batch_dict_by_pixel,
        id2pixel_key_map,
        drop_figures_set,
        vis_image_labels,
    ):
        """Turn per-block VL recognition output into per-page parsing results.

        Args:
            blocks: Per-page lists of block dicts. "img", "box" and "label" are
                read, plus the optional "group_id" and "polygon_points".
            batch_dict_by_pixel: Mapping pixel key -> bucket holding
                "vlm_block_ids", "vlm_results", "images", "figure_token_maps"
                and the cursor "curr_vlm_block_idx". The cursor is advanced in
                place and each consumed result dict gets an "image" entry.
            id2pixel_key_map: Mapping `(page_idx, block_idx)` -> pixel key for
                the blocks that were sent to the VL model.
            drop_figures_set: Image paths whose blocks must not be emitted.
            vis_image_labels: Labels whose crop is attached to the block as an
                RGB PIL image.

        Returns:
            tuple: `(parsing_res_lists, table_res_lists, spotting_res_list)`,
                each with one element per page. `table_res_lists` only ever
                contains empty lists, since nothing is appended to them here.
        """
        parsing_res_lists = []
        table_res_lists = []
        spotting_res_list = []
        # NOTE(review): these two accumulators are created once, before the page
        # loop, and never reset. The table post-processing at the end of each
        # page therefore also revisits tables of earlier pages. Confirm whether
        # cross-page accumulation is intended.
        image_path_to_obj_map = {}
        table_blocks = []
        for i, blocks_for_img in enumerate(blocks):
            parsing_res_list = []
            table_res_list = []
            spotting_res = {}
            for j, block in enumerate(blocks_for_img):
                block_img = block["img"]
                block_bbox = block["box"]
                block_label = block["label"]
                block_content = ""
                figure_token_map = {}
                # Only blocks that went through the VL model have an entry here;
                # the others keep an empty content string.
                if (i, j) in id2pixel_key_map:
                    pixel_key = id2pixel_key_map[(i, j)]
                    pixel_info = batch_dict_by_pixel[pixel_key]
                    # Results are consumed sequentially per bucket. The assert
                    # checks that the cursor points at this very block.
                    curr_vlm_block_idx = pixel_info["curr_vlm_block_idx"]
                    assert curr_vlm_block_idx < len(
                        pixel_info["vlm_block_ids"]
                    ) and pixel_info["vlm_block_ids"][curr_vlm_block_idx] == (i, j)
                    vl_rec_result = pixel_info["vlm_results"][curr_vlm_block_idx]
                    block_img4vl = pixel_info["images"][curr_vlm_block_idx]
                    figure_token_map = pixel_info["figure_token_maps"][
                        curr_vlm_block_idx
                    ]
                    curr_vlm_block_idx += 1
                    pixel_info["curr_vlm_block_idx"] = curr_vlm_block_idx
                    # Attach the image that was actually fed to the model.
                    vl_rec_result["image"] = block_img4vl
                    result_str = vl_rec_result.get("result", "")
                    if result_str is None:
                        result_str = ""
                    # Tables are given a much larger `min_count` than other blocks.
                    min_count = 5000 if block_label == "table" else 50
                    result_str = truncate_repetitive_content(
                        result_str, min_count=min_count
                    )
                    # When the text contains \( \) or \[ \] delimiter pairs,
                    # drop any existing "$" and rewrite the delimiters to
                    # " $ " / " $$ " form.
                    if ("\\(" in result_str and "\\)" in result_str) or (
                        "\\[" in result_str and "\\]" in result_str
                    ):
                        result_str = result_str.replace("$", "")

                        result_str = (
                            result_str.replace("\\(", " $ ")
                            .replace("\\)", " $")
                            .replace("\\[\\[", "\\[")
                            .replace("\\]\\]", "\\]")
                            .replace("\\[", " $$ ")
                            .replace("\\]", " $$ ")
                        )
                        # Formula numbers end up without any "$" delimiters.
                        if block_label == "formula_number":
                            result_str = result_str.replace("$", "")
                    if block_label == "table":
                        # Keep the raw model output when the conversion yields "".
                        html_str = convert_otsl_to_html(result_str)
                        if html_str != "":
                            result_str = html_str
                    if block_label == "spotting":
                        # Uses the size of the original crop (`block["img"]`),
                        # not of the image passed to the model.
                        h, w = block_img.shape[:2]
                        result_str, spotting_res = post_process_for_spotting(
                            result_str, w, h
                        )

                    block_content = result_str
                block_info = PaddleOCRVLBlock(
                    label=block_label,
                    bbox=block_bbox,
                    content=block_content,
                    group_id=block.get("group_id", None),
                    polygon_points=block.get("polygon_points", None),
                )
                if block_label == "table":
                    # Remember tables for the figure-token post-processing below.
                    table_blocks.append(
                        {
                            "figure_token_map": figure_token_map,
                            "block": block_info,
                        }
                    )
                if block_label in vis_image_labels and block_img is not None:
                    img_path = construct_img_path(block["label"], block["box"])
                    image_path_to_obj_map[img_path] = block_info
                    if img_path not in drop_figures_set:
                        import cv2

                        # Crops are BGR; convert to RGB before building the PIL image.
                        block_img = cv2.cvtColor(block_img, cv2.COLOR_BGR2RGB)
                        block_info.image = {
                            "path": img_path,
                            "img": Image.fromarray(block_img),
                        }
                    else:
                        # Dropped figures stay in `image_path_to_obj_map` but are
                        # left out of this page's parsing list.
                        continue

                parsing_res_list.append(block_info)
                del block_info, block_img
            # Replace figure tokens in the recorded table contents.
            for blk_info in table_blocks:
                block = blk_info["block"]
                figure_token_map = blk_info["figure_token_map"]
                block.content = untokenize_figure_of_table(
                    block.content, figure_token_map, image_path_to_obj_map
                )
            parsing_res_lists.append(parsing_res_list)
            table_res_lists.append(table_res_list)
            spotting_res_list.append(spotting_res)
            del parsing_res_list, table_res_list, spotting_res

        return parsing_res_lists, table_res_lists, spotting_res_list

    def get_layout_parsing_results(
        self,
        images,
        layout_det_results,
        imgs_in_doc,
        use_chart_recognition=False,
        use_seal_recognition=False,
        use_ocr_for_image_block=False,
        vlm_kwargs=None,
        merge_layout_blocks=True,
        layout_shape_mode="auto",
    ):
        """Prepare layout blocks, run VL recognition and assemble parsing results.

        Args:
            images: Page images, one per page.
            layout_det_results: Layout detection result per page (indexed like
                `images`).
            imgs_in_doc: Per-page images in the document (indexed like
                `images`); returned unchanged as the last tuple element.
            use_chart_recognition (bool): When False, "chart" blocks are treated
                as images instead of being recognised.
            use_seal_recognition (bool): When False, "seal" blocks are treated
                as images instead of being recognised.
            use_ocr_for_image_block (bool): When True, blocks labelled with
                `IMAGE_LABELS` are sent to the VL model too.
            vlm_kwargs (Optional[dict]): Extra arguments for the VL model. The
                keys `min_pixels` / `max_pixels` and their per-category
                variants (`ocr_`, `table_`, `chart_`, `formula_`, `seal_`
                prefixes) are consumed here; the rest is forwarded to the
                model. The dict passed in is not modified.
            merge_layout_blocks (bool): Whether blocks are merged during
                preparation.
            layout_shape_mode (str): Shape mode forwarded to box filtering and
                cropping.

        Returns:
            tuple: `(parsing_res_lists, table_res_lists, spotting_res_list,
                imgs_in_doc)`. The first three hold one element per page and are
                empty when `images` is empty.
        """
        # Copy first: the pixel-limit keys are popped below and must not
        # disappear from a dict the caller may reuse.
        vlm_kwargs = {} if vlm_kwargs is None else dict(vlm_kwargs)

        min_pixels = vlm_kwargs.pop("min_pixels", None)
        default_min_pixels = 112896 if min_pixels is None else min_pixels
        max_pixels = vlm_kwargs.pop("max_pixels", None)
        default_max_pixels = 1003520 if max_pixels is None else max_pixels

        # `image_labels`: blocks kept as pictures (not recognised).
        # `vis_image_labels`: blocks whose crop is attached to the result.
        vis_image_labels = IMAGE_LABELS + ["seal"]
        image_labels = [] if use_ocr_for_image_block else IMAGE_LABELS.copy()
        if not use_chart_recognition:
            image_labels += ["chart"]
            vis_image_labels += ["chart"]
        if not use_seal_recognition:
            image_labels += ["seal"]

        layout_prep_cfg = {
            "layout_shape_mode": layout_shape_mode,
            "merge_layout_blocks": merge_layout_blocks,
            "image_labels": image_labels,
            "use_chart_recognition": use_chart_recognition,
            "use_seal_recognition": use_seal_recognition,
        }
        # Per-category limits fall back to the general min/max pixels.
        for category in ("ocr", "table", "chart", "formula", "seal"):
            min_key = f"{category}_min_pixels"
            max_key = f"{category}_max_pixels"
            layout_prep_cfg[min_key] = vlm_kwargs.pop(min_key, default_min_pixels)
            layout_prep_cfg[max_key] = vlm_kwargs.pop(max_key, default_max_pixels)

        num_pages = len(images)
        page_payloads = [
            (
                i,
                images[i],
                layout_det_results[i],
                imgs_in_doc[i],
                layout_prep_cfg,
            )
            for i in range(num_pages)
        ]

        max_workers = (
            min(self.layout_prep_cpu_workers, num_pages) if num_pages > 0 else 0
        )
        if num_pages > 1 and max_workers > 1:
            page_results = self._paddleocr_vl_layout_prep_parallel_pages(
                page_payloads, max_workers
            )
        else:
            # Serial path; also covers zero pages, which yields no results
            # instead of indexing into an empty payload list.
            page_results = [
                self._paddleocr_vl_prepare_page_serial_benchmarked(p)
                for p in page_payloads
            ]

        (
            blocks,
            has_spotting,
            drop_figures_set,
            batch_dict_by_pixel,
            id2pixel_key_map,
        ) = self._paddleocr_vl_aggregate_vlm_batches(page_results)

        # Drop local references to the page-sized inputs before recognition.
        # The payloads hold the same images, so they are released as well.
        del images, layout_det_results, page_payloads, page_results

        self._paddleocr_vl_run_vl_recognition_batches(
            batch_dict_by_pixel, has_spotting, vlm_kwargs
        )

        (
            parsing_res_lists,
            table_res_lists,
            spotting_res_list,
        ) = self._paddleocr_vl_assemble_parsing_results(
            blocks,
            batch_dict_by_pixel,
            id2pixel_key_map,
            drop_figures_set,
            vis_image_labels,
        )

        return (
            parsing_res_lists,
            table_res_lists,
            spotting_res_list,
            imgs_in_doc,
        )

    def predict(
        self,
        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
        use_doc_orientation_classify: Union[bool, None] = False,
        use_doc_unwarping: Union[bool, None] = False,
        use_layout_detection: Union[bool, None] = None,
        use_chart_recognition: Union[bool, None] = None,
        use_seal_recognition: Union[bool, None] = None,
        use_ocr_for_image_block: Union[bool, None] = None,
        layout_threshold: Optional[Union[float, dict]] = None,
        layout_nms: Optional[bool] = None,
        layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
        layout_merge_bboxes_mode: Optional[str] = None,
        layout_shape_mode: Optional[str] = "auto",
        use_queues: Optional[bool] = None,
        prompt_label: Optional[Union[str, None]] = None,
        format_block_content: Union[bool, None] = None,
        repetition_penalty: Optional[float] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        merge_layout_blocks: Optional[bool] = None,
        markdown_ignore_labels: Optional[List[str]] = None,
        vlm_extra_args: Optional[dict] = None,
        **kwargs,
    ) -> PaddleOCRVLResult:
        """
        Predicts the layout parsing result for the given input.

        This method is a generator: results are produced lazily, one
        `PaddleOCRVLResult` per input image/page. Each batch goes through a CV
        stage (image reading, optional document preprocessing, optional layout
        detection) and then a VLM stage (`get_layout_parsing_results`). When
        queues are enabled the input, CV and VLM stages run in three worker
        threads connected by bounded queues.

        Args:
            input (Union[str, list[str], np.ndarray, list[np.ndarray]]): Input image path, list of image paths,
                                                                        numpy array of an image, or list of numpy arrays.
            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification. Default is False.
            use_doc_unwarping (Optional[bool]): Whether to use document unwarping. Default is False.
            use_layout_detection (Optional[bool]): Whether to use layout detection. Default is None.
            use_chart_recognition (Optional[bool]): Whether to use chart recognition. Default is None.
            use_seal_recognition (Optional[bool]): Whether to use seal recognition. Default is None.
            use_ocr_for_image_block (Optional[bool]): Forwarded to `get_model_settings`; the resolved value is passed
                to `get_layout_parsing_results`. Default is None.
            layout_threshold (Optional[Union[float, dict]]): The threshold value to filter out low-confidence predictions. Default is None.
            layout_nms (Optional[bool], optional): Whether to use layout-aware NMS. Defaults to `None`.
            layout_unclip_ratio (Optional[Union[float, Tuple[float, float], dict]], optional): The ratio of unclipping the bounding box.
                Defaults to `None`.
                If it's a single number, then both width and height are used.
                If it's a tuple of two numbers, then they are used separately for width and height respectively.
                If it's None, then no unclipping will be performed.
            layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to `None`.
            layout_shape_mode (Optional[str], optional): The mode for layout shape. Defaults to "auto", [ "rect", "quad","poly", "auto"] are supported.
                Any value other than "rect" enables the `return_layout_polygon_points` model setting.
            use_queues (Optional[bool], optional): Whether to run the stages in worker threads connected by queues.
                Defaults to `None`, in which case `self.use_queues` is used.
            prompt_label (Optional[Union[str, None]], optional): The label of the prompt in
                ['ocr', 'formula', 'table', 'chart', 'spotting', 'seal']. Only used when layout detection is disabled,
                in which case `None`/empty falls back to 'ocr'. 'chart' and 'seal' additionally switch on chart / seal
                recognition. Defaults to `None`.
            format_block_content (Optional[bool]): Whether to format the block content. Default is None.
            repetition_penalty (Optional[float]): The repetition penalty parameter used for VL model sampling. Default is None.
            temperature (Optional[float]): Temperature parameter used for VL model sampling. Default is None.
            top_p (Optional[float]): Top-p parameter used for VL model sampling. Default is None.
            min_pixels (Optional[int]): The minimum number of pixels allowed when the VL model preprocesses images. Default is None.
            max_pixels (Optional[int]): The maximum number of pixels allowed when the VL model preprocesses images. Default is None.
            max_new_tokens (Optional[int]): The maximum number of new tokens. Default is None.
            merge_layout_blocks (Optional[bool]): Whether to merge layout blocks. Default is None.
            markdown_ignore_labels (Optional[list[str]]): The list of ignored markdown labels. Default is None.
            vlm_extra_args (Optional[dict]): Extra entries merged into the VLM keyword arguments. They are applied last,
                so they override the named sampling parameters above on key collision. Default is None (no extras).
            **kwargs (Any): Additional settings to extend functionality. Not read by this method's body.

        Yields:
            PaddleOCRVLResult: The predicted layout parsing result of one image/page. If the model settings are
                invalid, a single `{"error": ...}` dict is yielded instead and the generator stops.

        Raises:
            AssertionError: If layout detection is disabled and `prompt_label` is not a supported label.
            RuntimeError: In queue mode, if one of the worker threads failed; the worker's exception is chained
                as `__cause__`.
        """
        model_settings = self.get_model_settings(
            use_doc_orientation_classify,
            use_doc_unwarping,
            use_layout_detection,
            use_chart_recognition,
            use_seal_recognition,
            use_ocr_for_image_block,
            format_block_content,
            merge_layout_blocks,
            markdown_ignore_labels,
        )

        model_settings["return_layout_polygon_points"] = layout_shape_mode != "rect"

        if not self.check_model_settings_valid(model_settings):
            yield {"error": "the input params for model settings are invalid!"}
            # Stop here: running the pipeline with invalid settings can only
            # fail later in a less helpful place.
            return

        if use_queues is None:
            use_queues = self.use_queues

        if vlm_extra_args is None:
            vlm_extra_args = {}

        if not model_settings["use_layout_detection"]:
            # Without layout detection the whole image is one block whose
            # label doubles as the VLM prompt selector.
            prompt_label = prompt_label if prompt_label else "ocr"
            supported_prompt_labels = [
                "ocr",
                "formula",
                "table",
                "chart",
                "spotting",
                "seal",
            ]
            assert (
                prompt_label.lower() in supported_prompt_labels
            ), f"Layout detection is disabled (use_layout_detection=False). 'prompt_label' must be one of {supported_prompt_labels}, but got '{prompt_label}'."
            if prompt_label.lower() == "chart":
                model_settings["use_chart_recognition"] = True
            elif prompt_label.lower() == "seal":
                model_settings["use_seal_recognition"] = True

        def _process_cv(batch_data, new_batch_size=None):
            # CV stage: split the sampler batch into sub-batches of
            # `new_batch_size` (whole batch when falsy) and yield one tuple of
            # per-image lists for each sub-batch.
            if not new_batch_size:
                new_batch_size = len(batch_data)

            for idx in range(0, len(batch_data), new_batch_size):
                instances = batch_data.instances[idx : idx + new_batch_size]
                input_paths = batch_data.input_paths[idx : idx + new_batch_size]
                page_indexes = batch_data.page_indexes[idx : idx + new_batch_size]
                page_counts = batch_data.page_counts[idx : idx + new_batch_size]

                image_arrays = self.img_reader(instances)

                if model_settings["use_doc_preprocessor"]:
                    doc_preprocessor_results = list(
                        self.doc_preprocessor_pipeline(
                            image_arrays,
                            use_doc_orientation_classify=use_doc_orientation_classify,
                            use_doc_unwarping=use_doc_unwarping,
                        )
                    )
                else:
                    doc_preprocessor_results = [
                        {"output_img": arr} for arr in image_arrays
                    ]

                doc_preprocessor_images = [
                    item["output_img"] for item in doc_preprocessor_results
                ]
                if model_settings["use_layout_detection"]:
                    layout_det_results = list(
                        self.layout_det_model(
                            doc_preprocessor_images,
                            threshold=layout_threshold,
                            layout_nms=layout_nms,
                            layout_unclip_ratio=layout_unclip_ratio,
                            layout_merge_bboxes_mode=layout_merge_bboxes_mode,
                            layout_shape_mode=layout_shape_mode,
                            filter_overlap_boxes=False,
                        )
                    )

                    imgs_in_doc = [
                        gather_imgs(doc_pp_img, layout_det_res["boxes"])
                        for doc_pp_img, layout_det_res in zip(
                            doc_preprocessor_images, layout_det_results
                        )
                    ]
                else:
                    # No layout model: fabricate a single full-image box per
                    # page, labelled with the requested prompt label.
                    layout_det_results = []
                    for doc_preprocessor_image in doc_preprocessor_images:
                        layout_det_results.append(
                            {
                                "input_path": None,
                                "page_index": None,
                                "boxes": [
                                    {
                                        "cls_id": 0,
                                        "label": prompt_label.lower(),
                                        "score": 1,
                                        "coordinate": [
                                            0,
                                            0,
                                            doc_preprocessor_image.shape[1],
                                            doc_preprocessor_image.shape[0],
                                        ],
                                    }
                                ],
                            }
                        )
                    imgs_in_doc = [[] for _ in layout_det_results]

                yield input_paths, page_indexes, page_counts, doc_preprocessor_images, doc_preprocessor_results, layout_det_results, imgs_in_doc

        def _process_vlm(results_cv):
            # VLM stage: run block recognition on one CV result tuple and
            # yield one PaddleOCRVLResult per image.
            (
                input_paths,
                page_indexes,
                page_counts,
                doc_preprocessor_images,
                doc_preprocessor_results,
                layout_det_results,
                imgs_in_doc,
            ) = results_cv

            (
                parsing_res_lists,
                table_res_lists,
                spotting_res_list,
                imgs_in_doc,
            ) = self.get_layout_parsing_results(
                images=doc_preprocessor_images,
                layout_det_results=layout_det_results,
                imgs_in_doc=imgs_in_doc,
                use_chart_recognition=model_settings["use_chart_recognition"],
                use_seal_recognition=model_settings["use_seal_recognition"],
                use_ocr_for_image_block=model_settings["use_ocr_for_image_block"],
                vlm_kwargs={
                    "repetition_penalty": repetition_penalty,
                    "temperature": temperature,
                    "top_p": top_p,
                    "min_pixels": min_pixels,
                    "max_pixels": max_pixels,
                    "max_new_tokens": max_new_tokens,
                    # Placed last so caller-supplied extras win on collision.
                    **vlm_extra_args,
                },
                merge_layout_blocks=model_settings["merge_layout_blocks"],
                layout_shape_mode=layout_shape_mode,
            )

            for (
                input_path,
                page_index,
                page_count,
                doc_preprocessor_image,
                doc_preprocessor_res,
                layout_det_res,
                table_res_list,
                parsing_res_list,
                spotting_res,
                imgs_in_doc_for_img,
            ) in zip(
                input_paths,
                page_indexes,
                page_counts,
                doc_preprocessor_images,
                doc_preprocessor_results,
                layout_det_results,
                table_res_lists,
                parsing_res_lists,
                spotting_res_list,
                imgs_in_doc,
            ):
                single_img_res = {
                    "input_path": input_path,
                    "page_index": page_index,
                    "page_count": page_count,
                    "width": doc_preprocessor_image.shape[1],
                    "height": doc_preprocessor_image.shape[0],
                    "doc_preprocessor_res": doc_preprocessor_res,
                    "layout_det_res": layout_det_res,
                    "table_res_list": table_res_list,
                    "parsing_res_list": parsing_res_list,
                    "spotting_res": spotting_res,
                    "imgs_in_doc": imgs_in_doc_for_img,
                    "model_settings": model_settings,
                }
                yield PaddleOCRVLResult(single_img_res)

        if use_queues:
            max_num_batches_in_process = 64
            queue_input = queue.Queue(maxsize=max_num_batches_in_process)
            queue_cv = queue.Queue(maxsize=max_num_batches_in_process)
            queue_vlm = queue.Queue(
                maxsize=self.batch_sampler.batch_size * max_num_batches_in_process
            )
            event_shutdown = threading.Event()
            event_data_loading_done = threading.Event()
            event_cv_processing_done = threading.Event()
            event_vlm_processing_done = threading.Event()

            # Queue items are either `(True, payload)`, `(False, stage, exc)`
            # for a failure in `stage`, or `None` as an end-of-stream sentinel.

            def _put_unless_shutdown(target_queue, item):
                # Blocking put that gives up once shutdown has been requested.
                # A plain `put()` on a full bounded queue would block forever
                # when the consumer stops early, keeping the non-daemon worker
                # (and therefore the interpreter) alive.
                # Returns True if the item was enqueued, False on shutdown.
                while not event_shutdown.is_set():
                    try:
                        target_queue.put(item, timeout=0.5)
                    except queue.Full:
                        continue
                    return True
                return False

            def _worker_input(input_):
                # Stage 1: pull batches from the batch sampler.
                all_batch_data = self.batch_sampler(input_)
                while not event_shutdown.is_set():
                    try:
                        batch_data = next(all_batch_data)
                    except StopIteration:
                        break
                    except Exception as e:
                        _put_unless_shutdown(queue_input, (False, "input", e))
                        break
                    else:
                        _put_unless_shutdown(queue_input, (True, batch_data))
                        del batch_data
                event_data_loading_done.set()

            def _worker_cv():
                # Stage 2: run the CV stage on every input batch.
                while not event_shutdown.is_set():
                    # Snapshot the flag BEFORE polling: the input worker sets
                    # it only after its last put, so "flag was set and the
                    # queue is still empty" proves nothing is left. Checking
                    # it after the timeout could miss a batch enqueued in
                    # between.
                    loading_done = event_data_loading_done.is_set()
                    try:
                        item = queue_input.get(timeout=0.5)
                    except queue.Empty:
                        if loading_done:
                            event_cv_processing_done.set()
                            # Sentinel to wake VLM worker
                            _put_unless_shutdown(queue_cv, None)
                            break
                        continue
                    if not item[0]:
                        # Forward the upstream failure and stop.
                        _put_unless_shutdown(queue_cv, item)
                        break
                    try:
                        for results_cv in _process_cv(
                            item[1],
                            (
                                self.layout_det_model.batch_sampler.batch_size
                                if model_settings["use_layout_detection"]
                                else None
                            ),
                        ):
                            if not _put_unless_shutdown(queue_cv, (True, results_cv)):
                                return
                            del results_cv
                        del item
                    except Exception as e:
                        _put_unless_shutdown(queue_cv, (False, "cv", e))
                        break

            def _worker_vlm():
                # Stage 3: accumulate CV results for up to
                # MAX_QUEUE_DELAY_SECS or until MAX_NUM_BOXES layout boxes are
                # collected, then run the VLM stage on the merged batch.
                MAX_QUEUE_DELAY_SECS = 0.5
                MAX_NUM_BOXES = self.vl_rec_model.batch_sampler.batch_size
                cv_done = False

                while not event_shutdown.is_set():
                    results_cv_list = []
                    start_time = time.time()
                    should_break = False
                    num_boxes = 0
                    while True:
                        remaining_time = MAX_QUEUE_DELAY_SECS - (
                            time.time() - start_time
                        )
                        if remaining_time <= 0:
                            break
                        try:
                            item = queue_cv.get(timeout=remaining_time)
                        except queue.Empty:
                            break
                        if item is None:
                            # Sentinel from CV worker — no more data coming
                            cv_done = True
                            break
                        if not item[0]:
                            _put_unless_shutdown(queue_vlm, item)
                            should_break = True
                            break
                        results_cv_list.append(item[1])
                        del item
                        # Index 5 of a CV result tuple is layout_det_results.
                        for res in results_cv_list[-1][5]:
                            num_boxes += len(res["boxes"])
                        if num_boxes >= MAX_NUM_BOXES:
                            break
                    if should_break:
                        break
                    if not results_cv_list:
                        if cv_done or event_cv_processing_done.is_set():
                            event_vlm_processing_done.set()
                            # Sentinel to wake consumer
                            _put_unless_shutdown(queue_vlm, None)
                            break
                        continue

                    # Concatenate the per-field lists of all collected CV
                    # results into one tuple-shaped batch.
                    merged_results_cv = [
                        list(chain.from_iterable(lists))
                        for lists in zip(*results_cv_list)
                    ]
                    del results_cv_list

                    try:
                        for result_vlm in _process_vlm(merged_results_cv):
                            if not _put_unless_shutdown(queue_vlm, (True, result_vlm)):
                                return
                            del result_vlm
                        del merged_results_cv
                    except Exception as e:
                        _put_unless_shutdown(queue_vlm, (False, "vlm", e))
                        break

                    # After processing accumulated batch, check if CV is done
                    if cv_done:
                        # Drain any remaining items from queue_cv.
                        # NOTE: the CV worker stops right after enqueuing its
                        # sentinel, so this drain is defensive only.
                        while True:
                            try:
                                item = queue_cv.get_nowait()
                            except queue.Empty:
                                break
                            if item is None or not item[0]:
                                break
                            merged = [list(field) for field in item[1]]
                            try:
                                for result_vlm in _process_vlm(merged):
                                    if not _put_unless_shutdown(
                                        queue_vlm, (True, result_vlm)
                                    ):
                                        return
                            except Exception as e:
                                _put_unless_shutdown(queue_vlm, (False, "vlm", e))
                                break
                        event_vlm_processing_done.set()
                        # Sentinel to wake consumer
                        _put_unless_shutdown(queue_vlm, None)
                        break

            thread_input = threading.Thread(
                target=_worker_input, args=(input,), daemon=False
            )
            thread_input.start()
            thread_cv = threading.Thread(target=_worker_cv, daemon=False)
            thread_cv.start()
            thread_vlm = threading.Thread(target=_worker_vlm, daemon=False)
            thread_vlm.start()

        try:
            if use_queues:
                # The loop condition is the single exit test: the VLM worker
                # sets its event only after its last result was enqueued, so
                # "event set and queue empty" means everything was consumed.
                while not (event_vlm_processing_done.is_set() and queue_vlm.empty()):
                    try:
                        item = queue_vlm.get(timeout=0.5)
                    except queue.Empty:
                        # Re-evaluate the loop condition instead of bailing
                        # out here, so a result enqueued just before the
                        # event was set is not dropped.
                        continue
                    if item is None:
                        # Sentinel — VLM worker is done
                        break
                    if not item[0]:
                        # Chain the worker's exception to keep its traceback.
                        raise RuntimeError(
                            f"Exception from the '{item[1]}' worker: {item[2]}"
                        ) from item[2]
                    else:
                        yield item[1]
            else:
                for batch_data in self.batch_sampler(input):
                    results_cv_list = list(_process_cv(batch_data))
                    assert len(results_cv_list) == 1, len(results_cv_list)
                    results_cv = results_cv_list[0]
                    yield from _process_vlm(results_cv)
                    # Drop references before the next batch is loaded.
                    del results_cv, results_cv_list, batch_data
        finally:
            if use_queues:
                # Also reached when the caller abandons the generator early.
                event_shutdown.set()
                thread_input.join(timeout=5)
                if thread_input.is_alive():
                    logging.warning("Input worker did not terminate in time")
                thread_cv.join(timeout=5)
                if thread_cv.is_alive():
                    logging.warning("CV worker did not terminate in time")
                thread_vlm.join(timeout=5)
                if thread_vlm.is_alive():
                    logging.warning("VLM worker did not terminate in time")

    def concatenate_markdown_pages(self, markdown_list: list) -> str:
        """
        Concatenate Markdown content from multiple pages into a single document.

        Args:
            markdown_list (list): A list containing Markdown data for each page.
                Each entry must provide its text under the "markdown_texts" key.

        Returns:
            str: The Markdown text of all pages, in order. Every page's text is
                preceded by a blank-line separator ("\\n\\n"), including the
                first one; an empty list gives an empty string.
        """
        # Single join instead of repeated `+=`, which re-copies the growing
        # string for every page.
        return "".join("\n\n" + res["markdown_texts"] for res in markdown_list)

    def concatenate_pages(
        self,
        res_list: list,
        merge_table: bool = True,
        title_level: bool = True,
        merge_pages: bool = False,
    ):
        """Deprecated alias of `restructure_pages()`.

        Logs a deprecation warning and delegates to `restructure_pages()`.

        Args:
            res_list: List of page parsing results
            merge_table: Whether to merge tables across pages
                (passed on as `merge_tables`)
            title_level: Whether to assign title levels
                (passed on as `relevel_titles`)
            merge_pages: Whether to concatenate all pages into one document
                (passed on as `concatenate_pages`)

        Returns:
            Whatever `restructure_pages()` returns for the same arguments.
        """
        deprecation_notice = (
            "DeprecationWarning: `concatenate_pages()` is deprecated as of v3.3.14 "
            "and will be removed in v3.4.0. "
            "Please use `restructure_pages()` instead. "
            "It provides better support for table merging and title restructuring."
        )
        logging.warning(deprecation_notice)
        return self.restructure_pages(
            res_list,
            merge_tables=merge_table,
            relevel_titles=title_level,
            concatenate_pages=merge_pages,
        )

    def restructure_pages(
        self,
        res_list: list,
        merge_tables: bool = True,
        relevel_titles: bool = True,
        concatenate_pages: bool = False,
    ):
        """Restructure layout parsing results from multiple pages.

        This is a generator. Every block of every page receives a running
        `global_block_id` / `global_group_id`; the per-page block lists are then
        optionally passed through `merge_tables_across_pages` and
        `assign_levels_to_parsing_res`.

        Entries of `res_list` that are not `BaseResult` instances are treated
        as plain dicts holding the page result under the "res" key; their block
        dicts are converted to `PaddleOCRVLBlock` objects and their
        "imgs_in_doc" entry is reset to an empty list. The page results are
        updated in place.

        Args:
            res_list: List of page parsing results
            merge_tables: Whether to merge tables across pages
            relevel_titles: Whether to relevel titles
            concatenate_pages: Whether to concatenate pages to a single document

        Yields:
            PaddleOCRVLResult: One result per page, or a single combined result
                (built on top of the first page's result) when
                `concatenate_pages` is True. Nothing is yielded for an empty
                `res_list`.
        """

        if len(res_list) == 0:
            # Generator: a bare return simply ends the iteration.
            return

        def _get_img_obj(block, model_settings):
            # Reuse an image already attached to the block; otherwise build a
            # path-only placeholder for block types that are kept as images.
            if block.get("image", None):
                return block["image"]
            if block["block_label"] in ("image", "seal") or (
                block["block_label"] == "chart"
                and not model_settings.get("use_chart_recognition", False)
            ):
                path = construct_img_path(block["block_label"], block["block_bbox"])
                return {"path": path, "img": None}
            return None

        def _convert_blocks_to_obj(blocks, model_settings):
            # Turn plain block dicts into PaddleOCRVLBlock objects.
            res = []
            for block in blocks:
                obj = PaddleOCRVLBlock(
                    label=block["block_label"],
                    bbox=block["block_bbox"],
                    polygon_points=block.get("block_polygon_points", None),
                    # Strip a leading markdown heading marker ("## ").
                    content=re.sub(r"^#+\s", "", block["block_content"]),
                    group_id=block.get("group_id", None),
                )
                if img := _get_img_obj(block, model_settings):
                    obj.image = img
                res.append(obj)
            return res

        global_block_id = 0
        obj_res_list = []
        for one_page_res in res_list:
            if not isinstance(one_page_res, BaseResult):
                one_page_res = one_page_res["res"]
                one_page_res["imgs_in_doc"] = []
                blocks = one_page_res.get("parsing_res_list", [])
                model_settings = one_page_res.get("model_settings", {})
                blocks = _convert_blocks_to_obj(blocks, model_settings)
            else:
                blocks = one_page_res["parsing_res_list"]
                model_settings = one_page_res.get("model_settings", {})
            parsing_res_list = []
            for block in blocks:
                # Document-wide running ids; the group id starts out equal to
                # the block id.
                block.global_block_id = global_block_id
                block.global_group_id = global_block_id
                global_block_id += 1
                parsing_res_list.append(block)

            one_page_res["parsing_res_list"] = parsing_res_list
            obj_res_list.append(one_page_res)
        res_list = obj_res_list

        blocks_by_page = [res["parsing_res_list"] for res in res_list]

        if merge_tables:
            blocks_by_page = merge_tables_across_pages(blocks_by_page)
        if relevel_titles:
            blocks_by_page = assign_levels_to_parsing_res(blocks_by_page)

        concatenate_res = []
        if concatenate_pages:
            all_imgs_in_doc = []
            for res in res_list:
                all_imgs_in_doc.extend(res.get("imgs_in_doc", []))
            res_list[0]["imgs_in_doc"] = all_imgs_in_doc
            # The first page's result is reused as the combined document.
            all_page_res = res_list[0]
            all_blocks = []
            for page_idx, blks in enumerate(blocks_by_page):
                for blk in blks:
                    blk.page_index = page_idx
                    all_blocks.append(blk)
            all_page_res["parsing_res_list"] = all_blocks
            all_page_res["page_index"] = None
            all_page_res["page_count"] = len(res_list)
            # `model_settings` is the one of the last page and may be an empty
            # dict when a page result carries none, so use `.get` rather than
            # indexing (which would raise KeyError).
            if model_settings.get("use_layout_detection", False):
                all_page_res["layout_det_res"] = [
                    res["layout_det_res"] for res in res_list
                ]
            if model_settings.get("use_doc_preprocessor", False):
                all_page_res["doc_preprocessor_res"] = [
                    res["doc_preprocessor_res"] for res in res_list
                ]
            concatenate_res.append(PaddleOCRVLResult(all_page_res))
        else:
            for page_idx, one_page_res in enumerate(res_list):
                one_page_res["parsing_res_list"] = blocks_by_page[page_idx]
                concatenate_res.append(PaddleOCRVLResult(one_page_res))

        yield from concatenate_res


class _BasePaddleOCRVLPipeline(AutoParallelImageSimpleInferencePipeline):
    """Shared base of the public PaddleOCR-VL pipeline classes."""

    @property
    def _pipeline_cls(self):
        """Return the concrete pipeline implementation class."""
        return _PaddleOCRVLPipeline

    def _get_batch_size(self, config):
        """Return the "batch_size" entry of `config`, or 1 when it is absent."""
        fallback_batch_size = 1
        return config.get("batch_size", fallback_batch_size)


# Public pipeline class for the "PaddleOCR-VL" entity; all behaviour comes
# from _BasePaddleOCRVLPipeline.
# NOTE(review): `entities` is presumably the name under which the pipeline is
# registered/looked up, and the decorator presumably guards on the "ocr"
# optional dependency group — confirm against their definitions.
@pipeline_requires_extra("ocr")
class PaddleOCRVLPipeline(_BasePaddleOCRVLPipeline):
    entities = "PaddleOCR-VL"


# Public pipeline class for the "PaddleOCR-VL-1.5" entity; all behaviour comes
# from _BasePaddleOCRVLPipeline.
# NOTE(review): `entities` is presumably the name under which the pipeline is
# registered/looked up — confirm against the base class.
@pipeline_requires_extra("ocr")
class PaddleOCRVL15Pipeline(_BasePaddleOCRVLPipeline):
    entities = "PaddleOCR-VL-1.5"


# Public pipeline class for the "PaddleOCR-VL-1.6" entity; all behaviour comes
# from _BasePaddleOCRVLPipeline.
# NOTE(review): `entities` is presumably the name under which the pipeline is
# registered/looked up — confirm against the base class.
@pipeline_requires_extra("ocr")
class PaddleOCRVL16Pipeline(_BasePaddleOCRVLPipeline):
    entities = "PaddleOCR-VL-1.6"
