# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import inspect
from typing import TYPE_CHECKING

import paddle
from paddle.jit.profiler import EventGuard, event_register

from ..infer_meta import convert_meta_to_input_spec
from ..utils import (
    ENV_SOT_EXPORT,
    Cache,
    InfoCollector,
    NewSymbolHitRateInfo,
    Singleton,
    SIRToCodeMap,
    StepInfoManager,
    SubGraphInfo,
    SubGraphRelationInfo,
    log_do,
)
from .export import export
from .interpreter import compile_sir

if TYPE_CHECKING:
    from paddle.static import InputSpec, Program

    from .builder import StatementIRBuilder
    from .statement_ir import ParametersHolder


def trace_back_frames():
    """Walk up the Python call stack and mark every caller frame's code
    object as "with graph" via the SOT runtime hook.
    """
    current = inspect.currentframe()
    while True:
        parent = current.f_back
        if parent is None:
            break
        paddle.framework.core.sot_set_with_graph(parent.f_code)
        current = parent


def clear_eager_tensor_name(output_tensors):
    """Blank out the ``name`` attribute of every tensor in *output_tensors*."""
    for tensor in output_tensors:
        tensor.name = ""


def _is_builtin_op(op):
    dialect_name, opname = op.name().split(".")
    return dialect_name == "builtin"


def _is_computation_op(op):
    """Return True for ops that do real computation: anything that is neither
    a builtin-dialect op nor a pure data-feeding op (``pd_op.data``)."""
    if _is_builtin_op(op):
        return False
    return op.name() not in ["pd_op.data"]


class UniqueIdGenerator:
    """Monotonically increasing integer id generator; first id is 1."""

    def __init__(self):
        self._id = 0

    def generate(self):
        """Advance the counter and return the fresh id."""
        self._id = self._id + 1
        return self._id

    def __call__(self):
        """Alias for :meth:`generate` so instances are directly callable."""
        return self.generate()


class TensorIdAllocator(metaclass=Singleton):
    """Singleton allocator that tags tensors with a stable unique id.

    The id is stored on the tensor object itself under ``TENSOR_ID_ATTR``,
    so repeated calls with the same tensor yield the same id.
    """

    TENSOR_ID_ATTR = "__tensor_id__"

    def __init__(self):
        self._id_generator = UniqueIdGenerator()

    def allocate(self, tensor):
        """Return the tensor's id, assigning a fresh one on first sight."""
        try:
            return getattr(tensor, self.TENSOR_ID_ATTR)
        except AttributeError:
            new_id = self._id_generator()
            setattr(tensor, self.TENSOR_ID_ATTR, new_id)
            return new_id


class FallbackWrapper:
    """
    Used to store and call static graph methods generated by paddle.jit.to_static
    """

    def __init__(self, compiled_fn, SIR, is_training: bool):
        # Static function produced by paddle.jit.to_static.
        self.compiled_fn = compiled_fn
        # Both are filled lazily (in graph_size or __call__) so that the
        # to_static tracing happens at most once per wrapper.
        self.partial_program = None
        self.concrete_program = None
        self.SIR = SIR  # for debug
        self.is_training = is_training
        # Set after the SIR has been exported once via ENV_SOT_EXPORT.
        self.exported = False
        # True until the first __call__ completes; consumed by the info
        # collectors and the compile-time accounting below.
        self.is_first_call = True

    def graph_size(self):
        """Return the number of computation ops (non-builtin, non-``pd_op.data``)
        in the compiled program's global block, tracing the program first if
        it has not been traced yet.
        """
        if self.partial_program is None:
            # Build input specs from the SIR's recorded input metas so we can
            # trace without real tensors.
            input_spec = convert_meta_to_input_spec(
                tuple(
                    self.SIR.symbol_meta_map[symbol]
                    for symbol in self.SIR.inputs
                )
            )
            (
                self.concrete_program,
                self.partial_program,
            ) = self.compiled_fn.get_concrete_program(input_spec)
            self.partial_program.training = self.is_training
        global_block_ops = self.concrete_program.main_program.global_block().ops
        non_builtin_ops = list(filter(_is_computation_op, global_block_ops))
        return len(non_builtin_ops)

    def collect_new_symbol_hit_rate(self, inputs, outputs):
        """Record the unique tensor ids of inputs and outputs into the
        InfoCollector. No-op unless NewSymbolHitRateInfo collection is enabled.
        """
        if not InfoCollector().need_collect(NewSymbolHitRateInfo):
            return
        input_tensor_ids = []
        output_tensor_ids = []
        # `inputs` is the *args tuple of __call__; the compiled SIR entry
        # point takes a single tuple of tensors as its only argument.
        assert len(inputs) == 1
        assert isinstance(inputs[0], tuple)
        for i, arg in enumerate(inputs[0]):
            assert isinstance(arg, paddle.Tensor), f"Expect Tensor, got {arg}"
            tensor_id = TensorIdAllocator().allocate(arg)
            input_tensor_ids.append(tensor_id)

        for i, out in enumerate(outputs):
            assert isinstance(out, paddle.Tensor)
            tensor_id = TensorIdAllocator().allocate(out)
            output_tensor_ids.append(tensor_id)

        InfoCollector().attach(
            NewSymbolHitRateInfo, input_tensor_ids, output_tensor_ids
        )

    def collect_subgraph_relation(self, inputs, outputs, partial_program_layer):
        """Pair each runtime input/output tensor with the static shape of its
        IR value and attach the result as SubGraphRelationInfo. No-op unless
        that collection is enabled.
        """
        if not InfoCollector().need_collect(SubGraphRelationInfo):
            return
        input_shape_infos = []
        output_shape_infos = []
        # 'fx' / 'fo' are presumably the forward input / output IR values of
        # the traced program — TODO confirm against program_attr's contract.
        forward_input_values = partial_program_layer.program.program_attr['fx']
        forward_output_values = partial_program_layer.program.program_attr['fo']
        assert len(inputs) == 1
        assert isinstance(inputs[0], tuple)
        assert len(inputs[0]) == len(forward_input_values)
        assert len(outputs) == len(forward_output_values)
        for i, arg in enumerate(inputs[0]):
            assert isinstance(arg, paddle.Tensor), f"Expect Tensor, got {arg}"
            tensor_id = TensorIdAllocator().allocate(arg)
            input_ir_shape = forward_input_values[i].shape
            input_real_shape = arg.shape
            input_shape_info = SubGraphRelationInfo.ConcreteShapeInfo(
                tensor_id, input_ir_shape, input_real_shape
            )
            input_shape_infos.append(input_shape_info)

        for i, out in enumerate(outputs):
            assert isinstance(out, paddle.Tensor)
            tensor_id = TensorIdAllocator().allocate(out)
            # NOTE(review): quick_index_map looks like a runtime-index ->
            # IR-value-index translation — confirm with PartialProgramLayer.
            output_ir_shape = forward_output_values[
                partial_program_layer._outputs.quick_index_map[i]
            ].shape
            output_real_shape = out.shape
            output_shape_info = SubGraphRelationInfo.ConcreteShapeInfo(
                tensor_id, output_ir_shape, output_real_shape
            )
            output_shape_infos.append(output_shape_info)

        InfoCollector().attach(
            SubGraphRelationInfo,
            self.SIR.name,
            input_shape_infos,
            output_shape_infos,
            self.is_first_call,
            self.graph_size(),
        )

    def collect_subgraph_info(self, program: Program):
        """Attach the program text, computation-op count and SIR name as
        SubGraphInfo. No-op unless that collection is enabled.
        """
        if not InfoCollector().need_collect(SubGraphInfo):
            return

        InfoCollector().attach(
            SubGraphInfo,
            str(program),
            self.graph_size(),
            self.SIR.name,
        )

    def update_compile_time_info(self, SIR, partial_program_layer):
        """On the first call only, add this subgraph's total compile time to
        the per-code-object stats kept by OpcodeExecutorCache.
        """
        if not self.is_first_call:
            return
        # Imported locally, presumably to avoid a circular import with the
        # executor cache module — TODO confirm.
        from ..opcode_translator.executor.executor_cache import (
            OpcodeExecutorCache,
        )

        code = SIRToCodeMap().get(SIR)
        assert code is not None, f"Cannot find code for SIR: {SIR}"

        OpcodeExecutorCache().compile_time_stats.setdefault(code, 0)
        OpcodeExecutorCache().compile_time_stats[code] += (
            partial_program_layer._compile_time_counter.get_total_time()
        )

    @event_register(
        lambda self, *args, **kwargs: f"FallbackWrapper: {self.SIR.name}"
    )
    def __call__(self, *args, **kwargs):
        """Run the compiled subgraph on the given arguments and return its
        outputs, tracing the static program on the first call and feeding the
        info collectors afterwards.
        """
        if StepInfoManager().need_back_trace:
            trace_back_frames()

        log_do(
            2,
            lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR),
        )
        log_do(
            4,
            lambda: print(
                self.compiled_fn.get_concrete_program(*args, **kwargs)[
                    1
                ].train_program
            ),
        )
        # Lazily trace the static program on first use and pin its
        # train/eval mode.
        if self.partial_program is None:
            with EventGuard("FallbackWrapper: get_concrete_program"):
                (
                    self.concrete_program,
                    self.partial_program,
                ) = self.compiled_fn.get_concrete_program(*args, **kwargs)
                self.partial_program.training = self.is_training
        outputs = self.partial_program.sot_call(*args, **kwargs)

        clear_eager_tensor_name(outputs)
        log_do(
            4,
            lambda: print("[CompileCache] run sir forward success."),
        )
        self.collect_new_symbol_hit_rate(args, outputs)
        self.collect_subgraph_relation(args, outputs, self.partial_program)
        self.collect_subgraph_info(self.concrete_program.main_program)
        self.update_compile_time_info(self.SIR, self.partial_program)
        # Export the SIR at most once when ENV_SOT_EXPORT is set.
        if ENV_SOT_EXPORT.get() != "" and not self.exported:
            export(self.SIR, ENV_SOT_EXPORT.get())
            self.exported = True

        self.is_first_call = False
        return outputs


class CompileSIRCache(Cache, metaclass=Singleton):
    """
    Cache the compiled function of SIR
    """

    def __init__(self):
        # Strong-reference cache: compiled functions are kept alive for the
        # lifetime of the process.
        super().__init__(weak=False)

    def key_fn(
        self,
        builder: StatementIRBuilder,
        sir_name: str,
        parameters_holder: ParametersHolder,
        input_spec: tuple[InputSpec | None, ...],
        **kwargs,
    ):
        """
        Generate a hash key for a SIR.

        Args:
            builder: The statement IR builder that owns the SIR.
            sir_name: The name of the SIR to compile.
            parameters_holder: Holder of the parameters captured by the SIR;
                only its identity (``id()``) participates in the key.
            input_spec: Input specs describing the SIR's inputs.
            **kwargs: Must contain ``training`` (bool); it is part of the key
                because train and eval programs differ.

        Returns:
            The hash key of the SIR.
        """
        sir = builder.get_sir(sir_name)
        # NOTE(dev): Is str(sir) a heavy operation ?
        hash_key = hash(
            (str(sir), *input_spec, id(parameters_holder), kwargs['training'])
        )
        return hash_key

    def value_fn(
        self,
        builder: StatementIRBuilder,
        sir_name: str,
        parameters_holder: ParametersHolder,
        input_spec: tuple[InputSpec | None, ...],
        **kwargs,
    ):
        """
        Generate the static graph function for a cache miss.

        Args:
            builder: The statement IR builder that owns the SIR.
            sir_name: The name of the SIR to compile.
            parameters_holder: Holder of the parameters captured by the SIR.
            input_spec: Input specs describing the SIR's inputs.
            **kwargs: Must contain ``training``; may contain
                ``build_strategy`` and ``backend`` (both default to None).

        Returns:
            A FallbackWrapper around the to_static-compiled SIR.
        """
        build_strategy = kwargs.get("build_strategy", None)
        backend = kwargs.get("backend", None)
        return FallbackWrapper(
            paddle.jit.to_static(
                compile_sir(builder, sir_name, parameters_holder),
                input_spec=[input_spec],
                build_strategy=build_strategy,
                backend=backend,
                full_graph=True,
            ),
            builder.get_sir(sir_name),
            is_training=kwargs['training'],
        )
