# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

os.environ["FLAGS_allocator_strategy"] = "auto_growth"
import cv2
import json
import paddle
import paddle.distributed as dist

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_re_results
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
from tools.program import ArgsParser, load_config, merge_config
from tools.infer_kie_token_ser import SerPredictor


class ReArgsParser(ArgsParser):
    """Command-line parser for joint SER + RE inference.

    On top of the options provided by ``ArgsParser`` (the RE ``--config`` /
    ``--opt``), it adds a second configuration file for the SER model
    (``--config_ser``) together with its override options (``--opt_ser``).
    """

    def __init__(self):
        super(ReArgsParser, self).__init__()
        # (flags, keyword arguments) for each SER-specific option.
        ser_options = (
            (("-c_ser", "--config_ser"), {"help": "ser configuration file to use"}),
            (
                ("-o_ser", "--opt_ser"),
                {"nargs": "+", "help": "set ser configuration options "},
            ),
        )
        for flags, kwargs in ser_options:
            self.add_argument(*flags, **kwargs)

    def parse_args(self, argv=None):
        """Parse ``argv`` and convert ``--opt_ser`` with the inherited ``_parse_opt``.

        Raises:
            AssertionError: if ``--config_ser`` was not supplied.
        """
        parsed = super(ReArgsParser, self).parse_args(argv)
        assert (
            parsed.config_ser is not None
        ), "Please specify --config_ser=ser_configure_file_path."
        parsed.opt_ser = self._parse_opt(parsed.opt_ser)
        return parsed


def make_input(ser_inputs, ser_results):
    """Build relation-extraction (RE) model inputs from SER inputs and predictions.

    Only the first sample is read (``ser_inputs[8][0]`` and ``ser_results[0]``).
    The entity and relation arrays built from it are then repeated
    ``batch_size`` times along a new leading axis.

    Args:
        ser_inputs: list of SER model inputs. The first two dims of
            ``ser_inputs[0].shape`` are taken as (batch_size, max_seq_len).
            ``ser_inputs[8][0]`` is a list of entity dicts with "start" and
            "end" keys (presumably token offsets; TODO confirm against
            SerPredictor).
        ser_results: SER predictions. ``ser_results[0]`` is a list of dicts
            with a "pred" key, one per entity, aligned with the entity list.

    Returns:
        tuple: ``(re_inputs, entity_idx_dict_batch)``.
            ``re_inputs`` is the first five SER inputs followed by the entity
            table and the relation table. Both tables are paddle tensors when
            ``ser_inputs[0]`` is one, numpy int64 arrays otherwise.
            ``entity_idx_dict_batch`` holds one reference per batch item to
            the same dict, which maps each kept (non-"O") entity's new index
            to its index in ``ser_results[0]``.
    """
    label_ids = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
    batch_size, max_seq_len = ser_inputs[0].shape[:2]
    ocr_entities = ser_inputs[8][0]
    sample_preds = ser_results[0]
    assert len(ocr_entities) == len(sample_preds)

    # Keep every entity not predicted as "O".
    # Each record is (original index, start, end, label id).
    kept = []
    for src_idx, (pred, ent) in enumerate(zip(sample_preds, ocr_entities)):
        if pred["pred"] != "O":
            kept.append((src_idx, ent["start"], ent["end"], label_ids[pred["pred"]]))

    entity_idx_dict = {new_idx: rec[0] for new_idx, rec in enumerate(kept)}
    starts = [rec[1] for rec in kept]
    ends = [rec[2] for rec in kept]
    kinds = [rec[3] for rec in kept]

    # Entity table. Row 0 holds each column's length and rows 1.. hold its
    # values; unused cells stay -1. Columns are (start, end, label id).
    entity_arr = np.full([max_seq_len + 1, 3], fill_value=-1, dtype=np.int64)
    for col, values in enumerate((starts, ends, kinds)):
        entity_arr[0, col] = len(values)
        entity_arr[1 : len(values) + 1, col] = values

    # Candidate relations: every (QUESTION, ANSWER) pair of kept entities,
    # ordered by question index, then answer index.
    question_ids = [k for k, kind in enumerate(kinds) if kind == 1]
    answer_ids = [k for k, kind in enumerate(kinds) if kind == 2]
    pairs = [(q, a) for q in question_ids for a in answer_ids]

    # Relation table. Row 0 holds the pair count in both columns and rows 1..
    # hold the (head, tail) pairs.
    relation_arr = np.full([len(pairs) + 1, 2], fill_value=-1, dtype=np.int64)
    relation_arr[0, :] = len(pairs)
    if pairs:
        relation_arr[1:, :] = pairs

    # Copy the single-sample tables onto a new leading batch axis.
    entity_arr = np.repeat(entity_arr[np.newaxis], batch_size, axis=0)
    relation_arr = np.repeat(relation_arr[np.newaxis], batch_size, axis=0)

    # Keep only the first five SER inputs, which drops ocr_info,
    # segment_offset_id and label, then append the RE tables.
    if isinstance(ser_inputs[0], paddle.Tensor):
        entity_arr = paddle.to_tensor(entity_arr)
        relation_arr = paddle.to_tensor(relation_arr)
    re_inputs = ser_inputs[:5] + [entity_arr, relation_arr]

    return re_inputs, [entity_idx_dict for _ in range(batch_size)]


class SerRePredictor(object):
    """End-to-end SER + RE predictor.

    Runs the SER engine on one input and converts its output into RE model
    inputs with ``make_input``. It then runs the RE model and returns the
    post-processed result.
    """

    def __init__(self, config, ser_config):
        """Build the SER engine, the RE post-processor and the RE model.

        Args:
            config: RE configuration dict. Reads the "Global", "PostProcess"
                and "Architecture" sections.
            ser_config: SER configuration dict. Its ``["Global"]["infer_mode"]``
                is overwritten in place when the RE "Global" section defines
                ``infer_mode``.
        """
        re_global = config["Global"]
        # Keep the SER engine's infer_mode in sync with the RE config.
        if "infer_mode" in re_global:
            ser_config["Global"]["infer_mode"] = re_global["infer_mode"]
        self.ser_engine = SerPredictor(ser_config)

        # RE side: build the post-processor first, then the model, then load
        # the model's weights and switch it to eval mode.
        self.post_process_class = build_post_process(config["PostProcess"], re_global)
        arch_cfg = config["Architecture"]
        self.model = build_model(arch_cfg)
        load_model(config, self.model, model_type=arch_cfg["model_type"])
        self.model.eval()

    def __call__(self, data):
        """Run SER then RE on ``data`` and return the post-processed RE output.

        Args:
            data: input dict passed unchanged to the SER engine (the script
                below builds it with an "img_path" key and optionally "label").
        """
        ser_results, ser_inputs = self.ser_engine(data)
        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
        # Without a visual backbone, input #4 is removed before calling the
        # RE model (presumably the image tensor; TODO confirm).
        if self.model.backbone.use_visual_backbone is False:
            del re_input[4]
        preds = self.model(re_input)
        return self.post_process_class(
            preds,
            ser_results=ser_results,
            entity_idx_dict_batch=entity_idx_dict_batch,
        )


def preprocess():
    """Parse the command line, load both configs and select the paddle device.

    The RE config comes from ``--config``/``--opt`` and the SER config from
    ``--config_ser``/``--opt_ser``. Both configs are logged. The device is
    chosen from the RE config's ``Global.use_gpu`` only.

    If ``use_gpu`` is true but the installed paddlepaddle is not compiled with
    CUDA, the function logs a warning and runs on CPU instead.

    Returns:
        tuple: ``(config, ser_config, device, logger)``. These are the merged
        RE config, the merged SER config, the value returned by
        ``paddle.set_device`` and the logger from ``get_logger()``.
    """
    FLAGS = ReArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)

    ser_config = load_config(FLAGS.config_ser)
    ser_config = merge_config(ser_config, FLAGS.opt_ser)

    logger = get_logger()

    # Check whether use_gpu=True was set on a CPU-only paddlepaddle build.
    # Such a build cannot select "gpu:N", so fall back to CPU with a warning
    # instead of failing inside paddle.set_device.
    use_gpu = config["Global"]["use_gpu"]
    if use_gpu and not paddle.is_compiled_with_cuda():
        logger.warning(
            "use_gpu=True is set but this paddlepaddle build is not compiled "
            "with CUDA; running on CPU instead."
        )
        use_gpu = False

    device = "gpu:{}".format(dist.ParallelEnv().dev_id) if use_gpu else "cpu"
    device = paddle.set_device(device)

    logger.info("{} re config {}".format("*" * 10, "*" * 10))
    print_dict(config, logger)
    logger.info("\n")
    logger.info("{} ser config {}".format("*" * 10, "*" * 10))
    print_dict(ser_config, logger)
    logger.info("train with paddle {} and device {}".format(paddle.__version__, device))
    return config, ser_config, device, logger


if __name__ == "__main__":
    config, ser_config, device, logger = preprocess()
    global_cfg = config["Global"]
    save_dir = global_cfg["save_res_path"]
    os.makedirs(save_dir, exist_ok=True)

    ser_re_engine = SerRePredictor(config, ser_config)

    # When infer_mode is explicitly False, "infer_img" is a text file whose
    # tab-separated lines hold an image path (relative to the Eval data_dir)
    # and a label. Otherwise it is resolved with get_image_file_list.
    use_label_file = global_cfg.get("infer_mode", None) is False
    if use_label_file:
        data_dir = config["Eval"]["dataset"]["data_dir"]
        with open(global_cfg["infer_img"], "rb") as f:
            infer_imgs = f.readlines()
    else:
        infer_imgs = get_image_file_list(global_cfg["infer_img"])
    num_imgs = len(infer_imgs)

    result_txt = os.path.join(save_dir, "infer_results.txt")
    with open(result_txt, "w", encoding="utf-8") as fout:
        for idx, info in enumerate(infer_imgs):
            if use_label_file:
                fields = info.decode("utf-8").strip("\n").split("\t")
                img_path = os.path.join(data_dir, fields[0])
                data = {"img_path": img_path, "label": fields[1]}
            else:
                img_path = info
                data = {"img_path": img_path}

            stem = os.path.splitext(os.path.basename(img_path))[0]
            save_img_path = os.path.join(save_dir, stem + "_ser_re.jpg")

            # The engine returns a per-batch list; only the first entry is used.
            result = ser_re_engine(data)[0]
            fout.write(
                "\t".join((img_path, json.dumps(result, ensure_ascii=False))) + "\n"
            )
            cv2.imwrite(save_img_path, draw_re_results(img_path, result))

            logger.info(
                "process: [{}/{}], save result to {}".format(
                    idx, num_imgs, save_img_path
                )
            )
