#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from paddle.utils import deprecated

from ...base.framework import IrGraph
from ...framework import _get_paddle_place, core

OpRole = core.op_proto_and_checker_maker.OpRole


class Quant2Int8OnednnPass:
    """
    Transform a quant model IrGraph into a oneDNN (formerly MKL-DNN) supported INT8 IrGraph.
    The pass consists of the following transformations:
        1. gather scale values from fake quantize/dequantize operators,
        2. extract FP32 inference model graph from the quant graph, i.e.
            a.  remove fake quantize/dequantize operators,
            b.  dequantize conv2d and mul's weights,
        3. optimize the FP32 graph using standard FP32 optimization fuses
            (e.g. `conv2d`+`bn` -> `conv2d`),
        4. quantize the optimized FP32 graph using standard INT8v2 quantization
            passes (`cpu_quantize_pass`, `cpu_quantize_squash_pass`).
    """

    def __init__(
        self,
        _ops_to_quantize,
        _op_ids_to_skip=None,
        _scope=None,
        _place=None,
        _core=None,
        _debug=False,
    ):
        """
        Record the pass configuration and set up empty bookkeeping state.

        Args:
            _ops_to_quantize: op type names handed to
                `cpu_quantize_placement_pass`; an empty collection makes
                `_is_quantizing_all_ops` report True.
            _op_ids_to_skip: kept as given, or `{-1}` when None.
            _scope: scope from which parameters are read and into which
                dequantized weights are written back.
            _place: normalised through `_get_paddle_place` before storing.
            _core: stored on the instance unchanged.
            _debug: when truthy, `_apply_pass` draws the graph after each pass.
        """
        # Caller-supplied configuration.
        self._ops_to_quantize = _ops_to_quantize
        if _op_ids_to_skip is None:
            self._op_ids_to_skip = {-1}
        else:
            self._op_ids_to_skip = _op_ids_to_skip
        self._scope = _scope
        self._place = _get_paddle_place(_place)
        self._core = _core
        self._debug = _debug

        # Names of the fake (simulated) quantization operators.
        self._fake_quantize_types = [
            'fake_quantize_moving_average_abs_max',
            'fake_quantize_range_abs_max',
        ]
        self._fake_dequantize_types = [
            'fake_dequantize_max_abs',
            'fake_channel_wise_dequantize_max_abs',
        ]
        self._fake_quantize_dequantize_types = [
            'fake_quantize_dequantize_abs_max',
            'fake_quantize_dequantize_moving_average_abs_max',
            'fake_channel_wise_quantize_dequantize_abs_max',
        ]

        # Ops whose input and output share one scale in `_propagate_scales`.
        self._scale_immutable_ops = [
            'transpose2',
            'reshape2',
            'pool2d',
            'slice',
            'shape',
            'nearest_interp',
            'nearest_interp_v2',
            'split',
        ]

        # Op type groups referenced by the individual transformations.
        self._scale_ops = ['scale']
        self._conv_ops = ['conv2d', 'depthwise_conv2d']
        self._pool_ops = ['pool2d']
        self._mul_ops = ['mul']
        self._fc_ops = ['fc']
        self._relu_ops = ['relu', 'relu6']
        self._matmul_ops = ['matmul', 'matmul_v2']
        self._gru_ops = ['fusion_gru', 'multi_gru']
        self._lstm_ops = ['fusion_lstm']

        # Mutable state filled in while the pass runs.
        self._weight_thresholds = {}
        # Input and output scales collected from the fake quant model,
        # keyed by variable name.
        self._var_quant_scales = {}
        self._max_range = {}
        self._s8_max = 127
        # Counter and prefix used to name the debug drawings.
        self._pass_idx = 0
        self._pass_group = 'int8'

    def apply(self, graph):
        """
        Run the full quant-to-INT8 pipeline on `graph` and return the result.

        The stages run strictly in the order listed below; each receives the
        graph returned by the previous one.
        """
        assert isinstance(graph, IrGraph), (
            'graph must be the instance of IrGraph.'
        )

        self._reset_pass_idx_and_group('int8')
        # NOTE: `_update_relu_output_scales` is deliberately not scheduled
        # between `_compute_weight_scales` and `_propagate_scales`: it causes
        # nondeterministic quantization behavior.
        stages = (
            self._label_skip_quantized_op,
            self._gather_weight_thresholds_from_fake,
            self._gather_input_scales_from_fake,
            self._gather_output_scales_from_attr,
            self._remove_fake_ops,
            self._dequantize_weights,
            self._optimize_fp32_graph,
            self._compute_weight_scales,
            self._propagate_scales,
            self._quantize_fp32_graph,
            self._cleanup,
        )
        for stage in stages:
            graph = stage(graph)
        return graph

    def prepare_and_optimize_fp32(self, graph):
        """
        Apply only the FP32 optimization passes and the final cleanup.

        Debug drawings produced along the way are grouped under 'fp32'.
        """
        assert isinstance(graph, IrGraph), (
            'graph must be the instance of IrGraph.'
        )

        self._reset_pass_idx_and_group('fp32')
        optimized = self._optimize_fp32_graph(graph)
        return self._cleanup(optimized)

    def _reset_pass_idx_and_group(self, group):
        """Restart the pass counter at zero and switch to pass group `group`."""
        self._pass_idx, self._pass_group = 0, group

    def _convert_scale2tensor(self, scale):
        """Wrap `scale` in a new `core.DenseTensor` placed on `core.CPUPlace()`."""
        scale_tensor = core.DenseTensor()
        scale_tensor.set(scale, core.CPUPlace())
        return scale_tensor

    def _is_quantizing_all_ops(self):
        """Return True when `_ops_to_quantize` contains no entries."""
        return not len(self._ops_to_quantize)

    def _is_any_of_op_types_in_graph(self, op_types, graph):
        """Return True if some op node of `graph` has its name in `op_types`."""
        for node in graph.all_op_nodes():
            if node.name() in op_types:
                return True
        return False

    def _is_any_of_op_types_quantized(self, op_types, graph):
        """
        Tell whether any of `op_types` is both present in `graph` and
        selected for quantization (explicitly, or because all ops are).
        """
        if not self._is_any_of_op_types_in_graph(op_types, graph):
            return False
        if self._is_quantizing_all_ops():
            return True
        for op_type in op_types:
            if op_type in self._ops_to_quantize:
                return True
        return False

    def _is_conv_quantized(self, graph):
        """Return whether a convolution op in `graph` is going to be quantized."""
        conv_types = self._conv_ops
        return self._is_any_of_op_types_quantized(conv_types, graph)

    def _is_fc_quantized(self, graph):
        """Return whether an `fc` op in `graph` is going to be quantized."""
        fc_types = self._fc_ops
        return self._is_any_of_op_types_quantized(fc_types, graph)

    def _label_skip_quantized_op(self, graph):
        """
        Mark conv2d, depthwise_conv2d, mul and matmul ops that should not be
        quantized, so that cpu_quantize_placement_pass can recognise them.

        Ops that carry a `quantization_type` attribute are left untouched.
        Any other target op gets `skip_quant=True` when at least one op
        producing one of its inputs has no "quantize" in its name.
        """
        candidates = self._conv_ops + self._mul_ops + self._matmul_ops
        for node in graph.all_op_nodes():
            if node.name() not in candidates:
                continue
            if node.op().has_attr("quantization_type"):
                continue
            fed_by_plain_op = any(
                "quantize" not in producer.name()
                for in_var in node.inputs
                for producer in in_var.inputs
            )
            if fed_by_plain_op:
                node.op()._set_attr("skip_quant", True)
        return graph

    def _add_scale_for_vars(self, var_names, use_unsigned_int, lod_tensor):
        """
        Register `(use_unsigned_int, lod_tensor)` as the quantization scale of
        every variable in `var_names` that has no scale yet. Existing entries
        are never overwritten.
        """
        entry = (use_unsigned_int, lod_tensor)
        for name in var_names:
            self._var_quant_scales.setdefault(name, entry)

    def _gather_input_scales_from_fake(self, graph):
        # fake_quantize_dequantize_abs_max doesn't have scale value
        fake_ops = ['fake_quantize_dequantize_moving_average_abs_max']
        fake_ops.extend(self._fake_quantize_types)

        for op in graph.all_op_nodes():
            if op.name() in fake_ops:
                bit_length = op.op().attr("bit_length")
                assert bit_length == 8, (
                    f'Unsupported number quantization bits ({bit_length}). Only 8 is supported now.'
                )

                input_name = op.input("X")[0]
                scale_name = op.input("InScale")[0]
                output_name = op.output("Out")[0]
                # Gather new weight scales after folding batchnorm in convolution
                scale = np.array(
                    1.0 / self._load_param(self._scope, scale_name)[0]
                ).astype(np.float64)
                scale[scale == np.inf] = 0.0
                lod_tensor = self._convert_scale2tensor(scale)
                use_unsigned_int = False
                self._add_scale_for_vars(
                    [input_name, output_name], use_unsigned_int, lod_tensor
                )

        return graph

    def _gather_weight_thresholds_from_fake(self, graph):
        for op in graph.all_op_nodes():
            if op.name() in self._fake_dequantize_types:
                input_name = op.input("X")[0]
                if op.op().has_attr("max_range"):
                    _max_range = np.array(op.op().attr("max_range")).astype(
                        np.float64
                    )
                    self._weight_thresholds[input_name] = np.array(
                        self._s8_max * self._s8_max / _max_range
                    ).astype(np.float64)
                else:
                    scale_name = op.input("Scales")[0]
                    self._weight_thresholds[input_name] = np.array(
                        self._load_param(self._scope, scale_name)
                    ).astype(np.float64)

        return graph

    def _gather_output_scales_from_attr(self, graph):
        """
        Derive output scales from the `out_threshold` attribute of ops.

        Each op with a non-zero `out_threshold` contributes the reciprocal of
        that value as a signed scale for all of its output variables that do
        not have a scale yet.
        """
        for op in graph.all_op_nodes():
            desc = op.op()
            if not desc.has_attr("out_threshold"):
                continue
            threshold = desc.attr("out_threshold")
            if threshold == 0.0:
                continue

            scale = np.array(1.0 / threshold).astype(np.float64)
            scale[scale == np.inf] = 0.0
            scale_tensor = self._convert_scale2tensor(scale)

            # Flatten the variable names of every output slot, keeping order.
            produced_vars = [
                var_name
                for slot in desc.outputs()
                for var_name in desc.output(slot)
            ]
            self._add_scale_for_vars(produced_vars, False, scale_tensor)

        return graph

    def _propagate_scales(self, graph):
        """
        Spread known quantization scales across ops that do not change them.

        * Scale-immutable ops copy the input scale to the output, or, when
          only the output scale is known, the output scale to the input.
        * `concat` copies a known output scale to all of its inputs.
        * `scale` ops derive the input scale from the output scale multiplied
          by the op's `scale` attribute.

        The graph is swept repeatedly until no variable is left waiting for a
        scale or a sweep leaves the set of waiting variables unchanged.
        """

        def _sweep(graph):
            # One pass over all ops; returns the names still without a scale.
            scales = self._var_quant_scales
            unresolved = set()
            for op in graph.all_op_nodes():
                op_type = op.name()
                if op_type in self._scale_immutable_ops:
                    # slice and shape name their data input "Input".
                    in_slot = "Input" if op_type in ('slice', 'shape') else "X"
                    src = op.input(in_slot)[0]
                    dst = op.output("Out")[0]
                    if src in scales:
                        scales[dst] = scales[src]
                    elif dst in scales:
                        scales[src] = scales[dst]
                    else:
                        unresolved.update([src, dst])
                elif op_type == 'concat':
                    dst = op.output("Out")[0]
                    if dst in scales:
                        for src in op.input("X"):
                            scales[src] = scales[dst]
                elif op_type in self._scale_ops:
                    src = op.input("X")[0]
                    dst = op.output("Out")[0]
                    if dst in scales:
                        unsigned, tensor = scales[dst]
                        rescaled = np.array(tensor) * op.op().attr("scale")
                        scales[src] = (
                            unsigned,
                            self._convert_scale2tensor(
                                rescaled.astype(np.float64)
                            ),
                        )
            return unresolved

        previous = set()
        pending = _sweep(graph)
        # Stop once nothing is pending or a sweep made no progress.
        while pending and pending != previous:
            previous, pending = pending, _sweep(graph)

        return graph

    def _load_param(self, scope, param_name):
        return np.array(scope.find_var(param_name).get_tensor())

    def _remove_fake_ops(self, graph):
        """
        Strip every fake quantize, fake dequantize and fake
        quantize-dequantize op from `graph`.

        Fake quantize-dequantize ops are removed with the same routine as
        fake dequantize ops.
        """
        for op in graph.all_op_nodes():
            kind = op.name()
            if kind in self._fake_quantize_types:
                self._remove_fake_quantize(graph, op)
            elif (
                kind in self._fake_dequantize_types
                or kind in self._fake_quantize_dequantize_types
            ):
                self._remove_fake_dequantize(graph, op)

        return graph

    def _remove_fake_quantize(self, graph, op):
        """
        Detach the fake quantize op `op` from `graph`.

        Every consumer of the op's `Out` variable is rewired to read the op's
        `X` input instead. Afterwards the op itself, its `InScale` input and
        its `Out` and `OutScale` outputs are removed from the graph.
        """

        def _input_node(slot):
            return graph._find_node_by_name(op.inputs, op.input(slot)[0])

        def _output_node(slot):
            return graph._find_node_by_name(op.outputs, op.output(slot)[0])

        source_var = _input_node("X")
        in_scale_var = _input_node("InScale")
        quantized_var = _output_node("Out")
        out_scale_var = _output_node("OutScale")

        for consumer in quantized_var.outputs:
            self._swap_inputs(consumer, quantized_var, source_var)
            graph.link_to(source_var, consumer)

        obsolete = {op, in_scale_var, quantized_var, out_scale_var}
        graph.safe_remove_nodes(obsolete)

        return graph

    def _remove_fake_dequantize(self, graph, op):
        """
        Detach the fake dequantize op `op` from `graph`.

        Consumers of the op's `Out` variable are rewired to the op's `X`
        input; then the op and its `Out` variable are removed.
        """
        source_name = op.input("X")[0]
        source_var = graph._find_node_by_name(op.inputs, source_name)
        result_name = op.output("Out")[0]
        result_var = graph._find_node_by_name(op.outputs, result_name)

        for consumer in result_var.outputs:
            self._swap_inputs(consumer, result_var, source_var)
            graph.link_to(source_var, consumer)

        graph.safe_remove_nodes({op, result_var})

        return graph

    def _swap_inputs(self, op, old_input, new_input):
        """
        Replace every occurrence of `old_input`'s name among the input
        arguments of `op` with the name of `new_input`. Input slots that do
        not reference `old_input` are left as they are.
        """
        old_name = old_input.name()
        desc = op.op()
        for slot in desc.input_names():
            current = op.input(slot)
            if old_name not in current:
                continue
            new_name = new_input.name()
            desc.set_input(
                slot,
                [new_name if arg == old_name else arg for arg in current],
            )

    def _dequantize_weights(self, graph):
        def _is_int8_weights(op_node, weight_name):
            weight_var_name = op_node.input(weight_name)[0]
            if self._scope.find_var(weight_var_name) is None:
                return False
            weight = self._load_param(self._scope, weight_var_name)
            return np.all(np.mod(weight, 1) == 0)

        mul_and_matmul_ops = self._mul_ops + self._matmul_ops
        for op in graph.all_op_nodes():
            if op.name() in self._conv_ops and _is_int8_weights(op, "Filter"):
                self._dequantize_op_weights(graph, op, "Filter", "Output")
            elif op.name() in mul_and_matmul_ops and _is_int8_weights(op, "Y"):
                self._dequantize_op_weights(graph, op, "Y", "Out")

        return graph

    def _dequantize_op_weights(self, graph, op_node, weight_name, output_name):
        weight_var_name = op_node.input(weight_name)[0]
        output_var_name = op_node.output(output_name)[0]
        # Convert int8 range weights to fp32 range weights
        scales = self._weight_thresholds[output_var_name]
        weight = self._load_param(self._scope, weight_var_name)
        if scales.size == 1 or scales.size == weight.shape[0]:
            w_fp32 = np.multiply(np.divide(weight, self._s8_max).T, scales.T).T
        elif len(weight.shape) > 1 and scales.size == weight.shape[1]:
            w_fp32 = np.multiply(np.divide(weight, self._s8_max), scales)
        else:
            raise ValueError(
                f"The size of weight scales vector ({scales.size}) does not match the dimensions ({weight.shape}) of the weights tensor {weight_var_name}."
            )
        w_fp32 = w_fp32.reshape(weight.shape).astype(np.float32)
        self._restore_var(weight_var_name, w_fp32)

    def _restore_var(self, name, array):
        """Overwrite the tensor of scope variable `name` with `array` on `self._place`."""
        self._scope.find_var(name).get_tensor().set(array, self._place)

    def _update_activations(self, graph):
        """
        Give every conv op that lacks a `fuse_activation` attribute one.

        The value is "relu" when the op has a truthy `fuse_relu` attribute
        and the empty string otherwise.
        """
        for op in graph.all_op_nodes():
            if op.name() not in self._conv_ops:
                continue
            if op.op().has_attr("fuse_activation"):
                continue
            relu_fused = op.op().has_attr("fuse_relu") and op.op().attr(
                "fuse_relu"
            )
            op.set_attr("fuse_activation", "relu" if relu_fused else "")
        return graph

    def _remove_ctrl_vars(self, graph):
        """Remove all variable nodes of `graph` that are control variables."""
        ctrl_vars = {
            var_node
            for var_node in graph.all_var_nodes()
            if var_node.is_ctrl_var()
        }
        graph.safe_remove_nodes(ctrl_vars)
        return graph

    def _optimize_fp32_graph(self, graph):
        """
        Optimize the FP32 graph with a fixed sequence of fuse passes.

        Conv activations are normalised and control variables removed first.
        Then the passes below are applied one by one in the listed order. The
        fc oneDNN passes are applied only if, at that point of the sequence,
        `_is_fc_quantized` reports True for the graph.
        """
        graph = self._update_activations(graph)
        graph = self._remove_ctrl_vars(graph)
        graph = self._apply_pass(
            graph, 'onednn_placement_pass', ['onednn_enabled_op_types'], [set()]
        )

        passes_before_fc_fuse = (
            'simplify_with_basic_ops_pass',  # removes dropout ops
            'layer_norm_fuse_pass',
            'attention_lstm_fuse_pass',
            'seqconv_eltadd_relu_fuse_pass',
            'fc_lstm_fuse_pass',
            'mul_lstm_fuse_pass',
            'fc_gru_fuse_pass',
            'mul_gru_fuse_pass',
            'multi_gru_fuse_pass',
            'multi_gru_seq_fuse_pass',
            'seq_concat_fc_fuse_pass',
            'gpu_cpu_squeeze2_matmul_fuse_pass',
            'gpu_cpu_reshape2_matmul_fuse_pass',
            'gpu_cpu_flatten2_matmul_fuse_pass',
            'matmul_v2_scale_fuse_pass',
            'squared_mat_sub_fuse_pass',
            'is_test_pass',
            'gpu_cpu_map_matmul_v2_to_mul_pass',
            'gpu_cpu_map_matmul_v2_to_matmul_pass',
            'matmul_scale_fuse_pass',
            'gpu_cpu_map_matmul_to_mul_pass',
            'repeated_fc_relu_fuse_pass',
            'depthwise_conv_onednn_pass',
            'conv_bn_fuse_pass',
            'conv_eltwiseadd_bn_fuse_pass',
            'conv_affine_channel_onednn_fuse_pass',
            'conv_transpose_bn_fuse_pass',
            'conv_transpose_eltwiseadd_bn_fuse_pass',
            'conv_bias_onednn_fuse_pass',
            'conv_transpose_bias_onednn_fuse_pass',
            'conv_elementwise_add_onednn_fuse_pass',
            'conv_activation_onednn_fuse_pass',
        )
        for pass_name in passes_before_fc_fuse:
            graph = self._apply_pass(graph, pass_name)

        graph = self._apply_pass(
            graph, 'fc_fuse_pass', ['use_gpu', 'use_fc_padding'], [False, False]
        )
        graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')

        # Applied only when fc is present in the graph and is to be quantized.
        if self._is_fc_quantized(graph):
            for pass_name in ('fc_onednn_pass', 'fc_act_onednn_fuse_pass'):
                graph = self._apply_pass(graph, pass_name)

        closing_passes = (
            'matmul_transpose_reshape_onednn_fuse_pass',
            'matmul_elementwise_add_onednn_fuse_pass',
            'matmul_activation_onednn_fuse_pass',
            'batch_norm_act_fuse_pass',
            'softplus_activation_onednn_fuse_pass',
            'scale_matmul_fuse_pass',
            'reshape_transpose_matmul_onednn_fuse_pass',
            # Must stay last since it works on all fused ops.
            'runtime_context_cache_pass',
        )
        for pass_name in closing_passes:
            graph = self._apply_pass(graph, pass_name)
        return graph

    def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None):
        """
        Fetch the pass called `pass_name` and apply it to `graph`.

        Args:
            graph: wrapper whose `graph` attribute is the object the pass is
                applied to.
            pass_name: name given to `core.get_pass`.
            attrs: optional attribute names to set on the pass.
            attr_values: values for `attrs`, matched by position.

        Returns:
            The same `graph` wrapper, with unused variable nodes removed.

        Raises:
            AssertionError: if `attrs` is given without an equally long
                `attr_values`.
        """
        pass_obj = core.get_pass(pass_name)
        native_graph = graph.graph
        # Make the parameter scope available to the pass if it is not yet.
        if not native_graph.has('__param_scope__'):
            native_graph.set_not_owned('__param_scope__', self._scope)

        if attrs:
            assert attr_values and len(attrs) == len(attr_values), (
                "Different number of pass attributes and their values."
            )
            for attr_name, attr_value in zip(attrs, attr_values):
                pass_obj.set(attr_name, attr_value)

        pass_obj.apply(native_graph)

        if self._debug:
            drawing_name = f'{self._pass_group}_{self._pass_idx}_{pass_name}'
            graph.draw('.', drawing_name, graph.all_op_nodes())

        self._remove_unused_var_nodes(graph)
        self._pass_idx += 1
        return graph

    def _cleanup(self, graph):
        """Drop unused variable nodes, then set the forward role on all ops."""
        pruned = self._remove_unused_var_nodes(graph)
        return self._set_op_role_forward(pruned)

    def _remove_unused_var_nodes(self, graph):
        """
        Remove variable nodes that no op of `graph` reads or writes.

        Nodes are compared through their `node` attribute rather than through
        the wrapper objects themselves.
        """
        referenced = set()
        for op_node in graph.all_op_nodes():
            referenced.update(op_node.inputs)
            referenced.update(op_node.outputs)
        referenced_handles = {wrapper.node for wrapper in referenced}

        orphans = {
            var_node
            for var_node in graph.all_var_nodes()
            if var_node.node not in referenced_handles
        }
        graph.safe_remove_nodes(orphans)
        return graph

    def _set_op_role_forward(self, graph):
        """Set the `op_role` attribute of every op node to `OpRole.Forward`."""
        for op_node in graph.all_op_nodes():
            op_node.set_attr("op_role", OpRole.Forward)
        return graph

    def _compute_weight_scales(self, graph):
        def _compute_var_scales(ops, w_name, axis):
            for op in graph.all_op_nodes():
                if op.op().type() in ops:
                    weight_var_name = op.input(w_name)[0]
                    weights = np.array(
                        self._load_param(self._scope, weight_var_name)
                    )
                    scales = 1.0 / np.amax(
                        np.abs(weights.reshape(weights.shape[0], -1)).astype(
                            np.float64
                        ),
                        axis=axis,
                    )
                    scales[scales == np.inf] = 0.0

                    lod_tensor = self._convert_scale2tensor(scales)
                    use_unsigned_int = False
                    self._var_quant_scales[weight_var_name] = (
                        use_unsigned_int,
                        lod_tensor,
                    )

        def _compute_single_gru_weight_scales(wx_var_name, wh_var_name):
            wx = np.array(self._load_param(self._scope, wx_var_name))
            wh = np.array(self._load_param(self._scope, wh_var_name))
            OC = wh.shape[0]
            scale_ur = 1.0 / np.max(
                np.abs(
                    np.concatenate(
                        [
                            wx[:, : 2 * OC],
                            wh.flatten()[: 2 * OC * OC].reshape(OC, 2 * OC),
                        ],
                        axis=0,
                    )
                ),
                axis=0,
            )
            scale_o = 1.0 / np.max(
                np.abs(
                    np.concatenate(
                        [
                            wx[:, 2 * OC :],
                            wh.flatten()[2 * OC * OC :].reshape(OC, OC),
                        ],
                        axis=0,
                    )
                ),
                axis=0,
            )

            gru_weights_scale = np.concatenate([scale_ur, scale_o]).astype(
                'float'
            )

            return self._convert_scale2tensor(gru_weights_scale)

        def _compute_gru_weight_scales(wx_name, wh_name):
            for op in graph.all_op_nodes():
                if op.op().type() in self._gru_ops:
                    assert len(op.input(wx_name)) == len(op.input(wh_name)), (
                        f'Mismatch in number of weights inputs ({len(op.input(wx_name))} for WeightX vs. {len(op.input(wh_name))} for WeightH).'
                    )
                    for i, wx_var_name in enumerate(op.input(wx_name)):
                        wh_var_name = op.input(wh_name)[i]
                        use_unsigned_int = False
                        lod_tensor = _compute_single_gru_weight_scales(
                            wx_var_name, wh_var_name
                        )
                        self._var_quant_scales[wx_var_name] = (
                            use_unsigned_int,
                            lod_tensor,
                        )

        def _compute_single_lstm_weight_scales(wx_var_name, wh_var_name):
            wx = np.array(self._load_param(self._scope, wx_var_name))
            wh = np.array(self._load_param(self._scope, wh_var_name))

            lstm_weights_scale = 1.0 / np.max(
                np.abs(np.concatenate([wx[:, :], wh[:, :]], axis=0)), axis=0
            )
            lstm_weights_scale = lstm_weights_scale.astype('float')

            return self._convert_scale2tensor(lstm_weights_scale)

        def _compute_lstm_weight_scales(wx_name, wh_name):
            for op in graph.all_op_nodes():
                if op.op().type() in self._lstm_ops:
                    assert len(op.input(wx_name)) == len(op.input(wh_name)), (
                        f'Mismatch in number of weights inputs ({len(op.input(wx_name))} for WeightX vs. {len(op.input(wh_name))} for WeightH).'
                    )
                    for i, wx_var_name in enumerate(op.input(wx_name)):
                        wh_var_name = op.input(wh_name)[i]
                        use_unsigned_int = False
                        lod_tensor = _compute_single_lstm_weight_scales(
                            wx_var_name, wh_var_name
                        )
                        self._var_quant_scales[wx_var_name] = (
                            use_unsigned_int,
                            lod_tensor,
                        )

        _compute_var_scales(self._conv_ops, "Filter", axis=1)
        _compute_var_scales(self._fc_ops, "W", axis=0)
        _compute_var_scales(self._gru_ops, "WeightH", axis=0)
        _compute_var_scales(self._lstm_ops, "WeightH", axis=0)
        _compute_gru_weight_scales("WeightX", "WeightH")
        _compute_lstm_weight_scales("WeightX", "WeightH")
        return graph

    def _update_relu_output_scales(self, graph):
        def _set_unsigned_scale(graph, ops, op_out_name, predicate):
            '''
            Sets the type of an output scale of a passed op type(s) to 'unsigned int8' if the
            predicate applied on op passes. Typically, the predicate checks if op's
            activation is set to relu.
            '''
            for op in graph.all_op_nodes():
                if op.name() in ops:
                    out_name = op.output(op_out_name)[0]
                    if out_name in self._var_quant_scales and predicate(
                        op.op()
                    ):
                        is_unsigned, tensor = self._var_quant_scales[out_name]
                        if is_unsigned is False:
                            # If the variable is signed, it means that the scales for this var
                            # were computed for signed data, so the scale must be multiplied by 2
                            # to fill the entire range of uint8
                            scale = np.array(tensor) * 2
                            tensor = self._convert_scale2tensor(
                                scale.astype(np.float64)
                            )
                        self._var_quant_scales[out_name] = (True, tensor)
            return graph

        def conv_predicate(op):
            return op.attr("fuse_activation") in self._relu_ops

        graph = _set_unsigned_scale(
            graph, self._conv_ops, "Output", conv_predicate
        )

        def fc_predicate(op):
            return op.attr("activation_type") in self._relu_ops

        graph = _set_unsigned_scale(graph, self._fc_ops, "Out", fc_predicate)

        graph = _set_unsigned_scale(
            graph, self._relu_ops, 'Out', lambda op: True
        )

        return graph

    def _get_data_layout(self, graph):
        """Return 'NHWC' when a conv op in `graph` is quantized, else 'NCHW'."""
        if self._is_conv_quantized(graph):
            return 'NHWC'
        return 'NCHW'

    def _quantize_fp32_graph(self, graph):
        """
        Turn the optimized FP32 graph into an INT8 graph.

        Two matmul fuse passes run first, then the quantize placement pass
        (restricted to `_ops_to_quantize`), the quantize pass (fed with the
        collected variable scales and the data layout), and finally the
        squash, scale calculation and parameter quantization passes.
        """
        for pass_name in (
            'scale_matmul_fuse_pass',
            'reshape_transpose_matmul_onednn_fuse_pass',
        ):
            graph = self._apply_pass(graph, pass_name)

        graph = self._apply_pass(
            graph,
            'cpu_quantize_placement_pass',
            ['quantize_enabled_op_types'],
            [self._ops_to_quantize],
        )

        # The layout is determined on the graph as left by the placement pass.
        data_layout = self._get_data_layout(graph)
        graph = self._apply_pass(
            graph,
            'cpu_quantize_pass',
            ['quant_var_scales', 'data_layout'],
            [self._var_quant_scales, data_layout],
        )

        for pass_name in (
            'cpu_quantize_squash_pass',
            'int8_scale_calculation_onednn_pass',
            'params_quantization_onednn_pass',
        ):
            graph = self._apply_pass(graph, pass_name)
        return graph


class Quant2Int8MkldnnPass(Quant2Int8OnednnPass):
    """
    Deprecated alias of `Quant2Int8OnednnPass` kept under the old MKL-DNN name.

    It adds no behaviour of its own: construction is forwarded unchanged to
    the parent class, and the constructor is marked as deprecated since 3.1.0.
    """

    @deprecated(
        since="3.1.0",
        update_to="paddle.static.quantization.Quant2Int8OnednnPass",
        level=1,
        reason="Quant2Int8MkldnnPass will be removed in future",
    )
    def __init__(self, *args, **kwargs):
        # Forward every argument untouched to Quant2Int8OnednnPass.__init__.
        super().__init__(*args, **kwargs)
