# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
#
# This is free software; you can redistribute it and/or modify
# it under the terms of the Apache 2.0 License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# Apache 2.0 License for more details.

from rapidfuzz.distance import Levenshtein
from apted import APTED, Config
from apted.helpers import Tree
from collections import deque
from .parallel import parallel_process
from tqdm import tqdm
from paddle.utils import try_import


class TableTree(Tree):
    """Tree node used to feed HTML tables into APTED.

    In this module ``td`` nodes carry an integer ``colspan``/``rowspan`` and a
    token list as ``content``; every other node is built with all three set
    to ``None``.
    """

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan, self.rowspan = colspan, rowspan
        self.content = content
        # Children are stored as a mutable list so callers can append to it.
        self.children = [*children]

    def bracket(self):
        """Render this subtree in bracket notation, e.g. ``{"tag": tr{...}}``."""
        if self.tag != "td":
            head = '"tag": %s' % self.tag
        else:
            head = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag,
                self.colspan,
                self.rowspan,
                self.content,
            )
        nested = "".join(child.bracket() for child in self.children)
        return "{" + head + nested + "}"


class CustomConfig(Config):
    """APTED cost configuration comparing table structure and cell text."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        Returns 1.0 when tag, colspan or rowspan differ. For two ``td`` nodes
        where at least one has content, returns the normalized Levenshtein
        distance between their token lists. Returns 0.0 otherwise.
        """
        same_shape = (
            node1.tag == node2.tag
            and node1.colspan == node2.colspan
            and node1.rowspan == node2.rowspan
        )
        if not same_shape:
            return 1.0
        if node1.tag != "td":
            return 0.0
        if not (node1.content or node2.content):
            return 0.0
        return Levenshtein.normalized_distance(node1.content, node2.content)


class CustomConfig_del_short(Config):
    """APTED cost configuration that masks very short cell contents."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        Same as the structural comparison in ``CustomConfig``. The exception is
        ``td`` text: any token list with fewer than 3 tokens is replaced by the
        single placeholder token ``"####"`` before computing the normalized
        Levenshtein distance.
        """

        def masked(tokens):
            # Short cells are collapsed to one placeholder token.
            return tokens if len(tokens) >= 3 else ["####"]

        structure_differs = (
            node1.tag != node2.tag
            or node1.colspan != node2.colspan
            or node1.rowspan != node2.rowspan
        )
        if structure_differs:
            return 1.0
        if node1.tag != "td" or not (node1.content or node2.content):
            return 0.0
        return Levenshtein.normalized_distance(
            masked(node1.content), masked(node2.content)
        )


class CustomConfig_del_block(Config):
    """APTED cost configuration that ignores space tokens in cell text."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        Returns 1.0 when tag, colspan or rowspan differ. For two ``td`` nodes
        where at least one has content, all ``" "`` tokens are dropped and the
        normalized Levenshtein distance of the remaining token lists is
        returned. Returns 0.0 otherwise.

        The node contents are not modified.
        """
        if (
            (node1.tag != node2.tag)
            or (node1.colspan != node2.colspan)
            or (node1.rowspan != node2.rowspan)
        ):
            return 1.0
        if node1.tag == "td":
            if node1.content or node2.content:
                # Build filtered copies. Popping from node.content in place
                # would corrupt the trees: APTED calls rename() many times for
                # the same nodes. Filtering is also O(n) rather than the O(n^2)
                # repeated index()/pop().
                node1_content = [tok for tok in node1.content if tok != " "]
                node2_content = [tok for tok in node2.content if tok != " "]
                return Levenshtein.normalized_distance(node1_content, node2_content)
        return 0.0


class TEDS(object):
    """Tree Edit Distance based Similarity (TEDS) between HTML tables.

    For a prediction / ground-truth pair the score is
    ``1 - distance / n_nodes``:

    * ``distance`` is the APTED tree edit distance between the two ``<table>``
      trees, using the rename costs from ``CustomConfig``.
    * ``n_nodes`` is the larger count of descendant elements of the two
      ``<table>`` elements.

    Args:
        structure_only: if True, cell text is ignored (``td`` content is an
            empty list), so only tag / colspan / rowspan structure is compared.
        n_jobs: number of parallel jobs passed to ``parallel_process`` by the
            batch methods. Must be an int >= 1; 1 runs sequentially.
        ignore_nodes: optional collection of tag names removed with
            ``lxml.etree.strip_tags`` before comparison. Their text and
            children are kept.
    """

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1
        ), "n_jobs must be an integer greater than or equal to 1"
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        # Scratch buffer filled by ``tokenize``; reset per cell in load_html_tree.
        self.__tokens__ = []

    def tokenize(self, node):
        """Append the token sequence of ``node``'s subtree to ``self.__tokens__``.

        The tokens are emitted in this order:

        * ``<tag>``;
        * one token per character of ``node.text``;
        * the tokens of each child, recursively;
        * ``</tag>``, omitted for ``unk`` tags;
        * one token per character of ``node.tail``, omitted for ``td`` so text
          after a cell is not attributed to it.
        """
        self.__tokens__.append("<%s>" % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        # Iterating an element yields its direct children. getchildren() is
        # deprecated in lxml and was removed from xml.etree in Python 3.9.
        for n in node:
            self.tokenize(n)
        if node.tag != "unk":
            self.__tokens__.append("</%s>" % node.tag)
        if node.tag != "td" and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        """Convert an element subtree to the ``TableTree`` format used by APTED.

        ``td`` elements become leaves carrying ``colspan`` / ``rowspan``
        (default 1) and their inner token list. The list is empty when
        ``structure_only`` is set. Other elements are recursed into.

        Returns:
            The new root ``TableTree`` when ``parent`` is None. Otherwise the
            new node is appended to ``parent.children`` and None is returned.
        """
        if node.tag == "td":
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # Drop the enclosing "<td>" / "</td>" tokens (slicing copies).
                cell = self.__tokens__[1:-1]
            new_node = TableTree(
                node.tag,
                int(node.attrib.get("colspan", "1")),
                int(node.attrib.get("rowspan", "1")),
                cell,
            )
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != "td":
            for child in node:
                self.load_html_tree(child, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        """Compute the TEDS score between a predicted and a ground-truth table.

        Both inputs are HTML strings that must contain a ``body/table`` path,
        e.g. ``<html><body><table>...</table></body></html>``.

        Returns:
            The TEDS score. It is 0.0 when either input is empty or either
            document has no ``body/table``. It is 1.0 when both tables have no
            descendant elements, i.e. they are identical bare ``<table>`` roots.
        """
        try_import("lxml")
        from lxml import etree, html

        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        if pred.xpath("body/table") and true.xpath("body/table"):
            pred = pred.xpath("body/table")[0]
            true = true.xpath("body/table")[0]
            if self.ignore_nodes:
                etree.strip_tags(pred, *self.ignore_nodes)
                etree.strip_tags(true, *self.ignore_nodes)
            n_nodes_pred = len(pred.xpath(".//*"))
            n_nodes_true = len(true.xpath(".//*"))
            n_nodes = max(n_nodes_pred, n_nodes_true)
            if n_nodes == 0:
                # Both trees are a lone <table> root: identical. Returning
                # here avoids a ZeroDivisionError below.
                return 1.0
            tree_pred = self.load_html_tree(pred)
            tree_true = self.load_html_tree(true)
            distance = APTED(
                tree_pred, tree_true, CustomConfig()
            ).compute_edit_distance()
            return 1.0 - (float(distance) / n_nodes)
        else:
            return 0.0

    def batch_evaluate(self, pred_json, true_json):
        """Compute TEDS scores for a batch of samples keyed by filename.

        Filenames missing from ``pred_json`` are scored against an empty
        prediction, which yields 0.0.

        @params pred_json: {'FILENAME': 'HTML CODE', ...}
        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
        @output: {'FILENAME': 'TEDS SCORE', ...}
        """
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [
                self.evaluate(pred_json.get(filename, ""), true_json[filename]["html"])
                for filename in tqdm(samples)
            ]
        else:
            inputs = [
                {
                    "pred": pred_json.get(filename, ""),
                    "true": true_json[filename]["html"],
                }
                for filename in samples
            ]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
            )
        scores = dict(zip(samples, scores))
        return scores

    def batch_evaluate_html(self, pred_htmls, true_htmls):
        """Compute TEDS scores for parallel sequences of HTML strings.

        Pairs are formed with ``zip``, so extra items in the longer sequence
        are ignored. Returns the scores in pair order.
        """
        if self.n_jobs == 1:
            scores = [
                self.evaluate(pred_html, true_html)
                for (pred_html, true_html) in zip(pred_htmls, true_htmls)
            ]
        else:
            inputs = [
                {"pred": pred_html, "true": true_html}
                for (pred_html, true_html) in zip(pred_htmls, true_htmls)
            ]

            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
            )
        return scores


if __name__ == "__main__":
    # Demo: score the predictions in sample_pred.json against sample_gt.json
    # and pretty-print the per-file TEDS scores.
    import json
    import pprint

    # Explicit UTF-8 so that non-ASCII cell text in the sample files does not
    # depend on the platform's default encoding.
    with open("sample_pred.json", encoding="utf-8") as fp:
        pred_json = json.load(fp)
    with open("sample_gt.json", encoding="utf-8") as fp:
        true_json = json.load(fp)
    teds = TEDS(n_jobs=4)
    scores = teds.batch_evaluate(pred_json, true_json)
    pp = pprint.PrettyPrinter()
    pp.pprint(scores)
