#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import textwrap
from collections.abc import Sequence

from paddle.base import core
from paddle.framework import use_pir_api
from paddle.utils import gast

from .utils import ORIGIN_INFO

# Intentionally empty: ``from <this module> import *`` exports no names.
__all__ = []


class Location:
    """
    A position in a source file: file path, line number and an optional
    column offset.
    """

    __slots__ = ("filepath", "lineno", "col_offset")

    def __init__(self, filepath, lineno, col_offset=None):
        self.filepath, self.lineno, self.col_offset = (
            filepath,
            lineno,
            col_offset,
        )

    def __str__(self):
        # Rendered as "location: <filepath>:<lineno>:<col_offset>".
        parts = (self.filepath, self.lineno, self.col_offset)
        return "location: " + ":".join(str(part) for part in parts)

    @property
    def line_location(self):
        """The ``(filepath, lineno)`` pair, ignoring the column."""
        return self.filepath, self.lineno


class OriginInfo:
    """
    Original source information for a piece of code: its location, the name
    of the function it belongs to, and the text of its source line.
    """

    __slots__ = ("location", "function_name", "source_code")

    def __init__(self, location, function_name, source_code):
        self.location = location
        self.function_name = function_name
        self.source_code = source_code

    def __str__(self):
        return f"{self.location} \nsource_code: {self.source_code}  in function {self.function_name}\n  "

    def formatted_message(self):
        """Render as a traceback-style frame tagged as user code."""
        loc = self.location
        marker = "(* user code *)"
        header = f'    File "{loc.filepath}", line {loc.lineno}, in {self.function_name} {marker}'
        return f"{header}\n\t{self.source_code.lstrip()}"

    def as_frame(self):
        """Return ``(filepath, lineno, function_name, stripped_source)``."""
        loc = self.location
        return (
            loc.filepath,
            loc.lineno,
            self.function_name,
            self.source_code.lstrip(),
        )


class OriginInfoAttacher(gast.NodeTransformer):
    """
    Attach an ``OriginInfo`` to every AST node that has a non-None ``lineno``.

    Node line numbers are treated as 1-based offsets into the source of
    ``func``. They are made absolute by adding the function's starting line
    offset. Column offsets are shifted by the indentation of the function's
    first source line.
    """

    def __init__(self, root, func):
        self.root = root
        self.func = inspect.unwrap(func)
        self.filepath = inspect.getsourcefile(self.func)
        # Full source text of the unwrapped function (not read by this class).
        self.source_code = inspect.getsource(self.func)
        # Stack of enclosing FunctionDef nodes; the last one is the innermost.
        self.current_func = []

    def transform(self):
        """Compute source offsets for ``func`` and annotate ``root`` in place."""
        lines, first_lineno = inspect.getsourcelines(self.func)
        header = lines[0]
        # Indentation width of the first source line.
        self.col_offset = len(header) - len(header.lstrip())
        self.source_lines = [text.strip("\n") for text in lines]
        self.lineno_offset = first_lineno - 1
        self.visit(self.root)

    def visit(self, node):
        """
        Visit ``node`` and its subtree, attributing each node to the innermost
        enclosing FunctionDef (a FunctionDef is attributed to itself).
        """
        is_function = isinstance(node, gast.FunctionDef)
        if is_function:
            self.current_func.append(node)
        if getattr(node, "lineno", None) is not None:
            self._attach_origin_info(node)
        self.generic_visit(node)
        if is_function:
            self.current_func.pop()
        return node

    def _attach_origin_info(self, node):
        """Store an ``OriginInfo`` on ``node`` under the ``ORIGIN_INFO`` attribute."""
        assert isinstance(node, gast.AST)
        assert hasattr(node, "lineno")

        location = Location(
            self.filepath, self._abs_lineno(node), self._abs_col_offset(node)
        )
        info = OriginInfo(
            location,
            self.current_func[-1].name,
            self.source_lines[node.lineno - 1],
        )
        setattr(node, ORIGIN_INFO, info)

    def _abs_lineno(self, node):
        """Line number of ``node`` within the whole file."""
        return self.lineno_offset + node.lineno

    def _abs_col_offset(self, node):
        """Column of ``node`` after re-adding the function's indentation."""
        return self.col_offset + node.col_offset


# Module-level map from a static-code line, keyed by ``(filepath, lineno)``
# (``Location.line_location``), to the ``OriginInfo`` of the dygraph code it
# came from. ``create_and_update_origin_info_map`` merges entries into it, and
# ``update_op_callstack_with_origin_info`` reads it to rewrite op callstacks.
global_origin_info_map = {}


def create_and_update_origin_info_map(
    transformed_node, static_func, is_global=True
):
    """
    Map lines of a generated static function back to the original dygraph
    source, and merge the result into ``global_origin_info_map``.

    Args:
        transformed_node(gast.AST): The AST node of transformed dygraph function with attached source information of original dygraph function.
        static_func(Callable): The static function transformed by dygraph function corresponding to transformed_node.
        is_global(bool, optional): If True (default), return the module-level
            ``global_origin_info_map`` after merging. Otherwise return only
            the map built by this call.

    Returns:
        dict: ``(filepath, lineno)`` of a static-code line -> ``OriginInfo``
        of the dygraph node it corresponds to.
    """
    static_ast = gast.parse(textwrap.dedent(inspect.getsource(static_func)))
    static_ast = attach_origin_info(static_ast, static_func)

    local_map = {}
    for t_node, s_node in ast_walk(transformed_node, static_ast):
        assert type(t_node) == type(s_node), (
            f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}."
        )
        dygraph_info = getattr(t_node, ORIGIN_INFO, None)
        static_info = getattr(s_node, ORIGIN_INFO, None)
        if dygraph_info is None or static_info is None:
            continue

        key = static_info.location.line_location
        existing = local_map.get(key)
        # Keep the entry already recorded for this static line unless the new
        # dygraph node is on a later line AND starts at a smaller column.
        if existing is not None and (
            existing.location.lineno >= dygraph_info.location.lineno
            or existing.location.col_offset
            <= dygraph_info.location.col_offset
        ):
            continue
        local_map[key] = dygraph_info

    global_origin_info_map.update(local_map)
    return global_origin_info_map if is_global else local_map


def attach_origin_info(ast_node, func):
    """
    Annotate every node of ``ast_node`` that has a ``lineno`` with the
    ``OriginInfo`` derived from the source of ``func``.

    Args:
        ast_node(gast.AST): The AST to annotate; nodes are modified in place.
        func(Callable): The function whose source ``ast_node`` was parsed from.

    Returns:
        The same ``ast_node`` object, now annotated.
    """
    OriginInfoAttacher(ast_node, func).transform()
    return ast_node


def ast_walk(transformed_node, static_node):
    """
    Walk two ASTs in lockstep, yielding ``(transformed, static)`` node pairs,
    starting with the given roots.

    Traversal is depth-first over explicit stacks, so the visiting order is
    deterministic. ``None`` is treated as an empty tree, and a sequence as a
    list of roots.

    A pair whose node types differ is skipped, along with its children, only
    when either side is ``gast.Load`` or ``gast.Param``. Any other mismatch
    fails an ``assert``.

    NOTE(liym27):
        Function ast.walk is not used because it yield all descendant nodes in no specified order.
    """

    def _to_list(obj):
        if obj is None:
            return []
        if isinstance(obj, Sequence):
            return list(obj)
        return [obj]

    # Parallel stacks: the entries at the same index form a pair.
    t_stack = _to_list(transformed_node)
    s_stack = _to_list(static_node)

    while t_stack:
        assert len(t_stack) == len(s_stack)
        t_node, s_node = t_stack.pop(), s_stack.pop()

        if type(t_node) != type(s_node):
            # NOTE(liym27):
            # Node types should be strictly required, but there is no strict
            # distinction between gast.Load and gast.Param in the ast
            # transformation process.
            loose_ctx = (gast.Load, gast.Param)
            if isinstance(t_node, loose_ctx) or isinstance(s_node, loose_ctx):
                continue

        assert type(t_node) == type(s_node), (
            f"The node types should be the same, but received type(t_node) is {type(t_node)}, and type(s_node) is {type(s_node)}."
        )

        yield t_node, s_node

        for field in t_node._fields:
            t_child = getattr(t_node, field)
            s_child = getattr(s_node, field)

            if isinstance(t_child, gast.AST):
                child_pairs = [(t_child, s_child)]
            elif isinstance(t_child, (list, tuple)):
                assert len(t_child) == len(s_child)
                child_pairs = [
                    (t_item, s_item)
                    for t_item, s_item in zip(t_child, s_child)
                    if isinstance(t_item, gast.AST)
                ]
            else:
                continue

            for t_item, s_item in child_pairs:
                t_stack.append(t_item)
                s_stack.append(s_item)


def update_op_callstack_with_origin_info(program):
    """
    Replace op creation callstacks that point into transformed static code
    with the corresponding original dygraph source locations.

    Each callstack frame whose ``(filepath, lineno)`` is a key of
    ``global_origin_info_map`` is replaced by the recorded dygraph frame.
    Other frames are re-emitted with the same content and normalized
    indentation.

    Args:
        program: The program whose ops' callstacks are rewritten. When
            ``use_pir_api()`` is true, ops inside nested sub-blocks are
            visited as well.

    Returns:
        The same ``program`` object, updated in place.
    """

    def get_new_op_callstack(callstack):
        """
        Rewrite ``callstack`` in place and return it.

        ``callstack`` is a flat list that alternates a frame header line and
        the source line of that frame. An example of callstack:

            File "path1/to/file.py", line 10, in func_1
                y = paddle.tensor.fill_constant(x, shape=[1], dtype="int32")
            File "path2/to/file.py", line 740, in fill_constant
                stop_gradient=True)
            File "path3/to/file.py", line 43, in append_op
              return self.main_program.current_block().append_op(*args, **kwargs)
            File "path4/to/file.py", line 2811, in append_op
              attrs=kwargs.get("attrs", None))
            File "path5/to/file.py", line 1919, in __init__
              for frame in traceback.extract_stack():
        """

        assert len(callstack) % 2 == 0
        for i in range(0, len(callstack), 2):
            # Split from the right: a comma inside the file path must not
            # shift the "line N" / "in func" fields (plain split(",") would
            # make int() fail on such paths).
            file_line = callstack[i].lstrip(" ").rsplit(",", 2)

            filepath = file_line[0][6:-1]  # drop leading 'File "' and trailing '"'
            lineno = int(file_line[1][6:])  # drop leading ' line '
            funcname = file_line[2][4:]  # drop leading ' in '
            code = callstack[i + 1].lstrip(" ")

            loc = Location(filepath, lineno)
            dygraph_func_info = global_origin_info_map.get(loc.line_location)
            if dygraph_func_info:
                filepath, lineno, funcname, code = dygraph_func_info.as_frame()

            callstack[i] = f'  File "{filepath}", line {lineno}, in {funcname}'
            callstack[i + 1] = f'    {code}'

        return callstack

    def get_all_pir_block_ops(block):
        """Collect ops of ``block`` and, recursively, of all their sub-blocks."""
        ops = []
        for op in block.ops:
            ops.append(op)
            for sub_block in op.blocks():
                ops += get_all_pir_block_ops(sub_block)
        return ops

    op_maker = core.op_proto_and_checker_maker
    callstack_var_name = op_maker.kOpCreationCallstackAttrName()

    if use_pir_api():
        global_block = program.global_block()
        ops = get_all_pir_block_ops(global_block)
        for op in ops:
            if op.has_attr(callstack_var_name):
                op.callstack = get_new_op_callstack(op.callstack)
    else:
        for block in program.blocks:
            for op in block.ops:
                if op.has_attr(callstack_var_name):
                    callstack = op.attr(callstack_var_name)

                    callstack = get_new_op_callstack(callstack)

                    try:
                        # (@xiongkun) In 2-order derivative for paddle science, there may exists `pow_grad`
                        # which has op_proto == nullptr and causes _set_attr failed. so we add a try...except.
                        op._set_attr(callstack_var_name, callstack)
                    except Exception:
                        # Best-effort: leave this op's callstack unchanged.
                        # Catch Exception rather than a bare except so that
                        # KeyboardInterrupt / SystemExit still propagate.
                        pass
    return program
