# -*- coding: UTF-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import os
import warnings
from functools import reduce

import paddle
from paddle import framework
from paddle.distributed.transpiler.details.program_utils import delete_ops
from paddle.framework import core
from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
    _get_lr_ops,
    _get_optimize_ops,
    get_sparse_tablenames,
)
from paddle.incubate.distributed.fleet.parameter_server.mode import (
    DistributedMode,
)

OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "gradient_clip"
STEP_COUNTER = "@PS_STEP_COUNTER@"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()

SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
SPARSE_GRAD_OP_TYPE_DICT = {
    "lookup_table_grad": "W",
    "lookup_table_v2_grad": "W",
}
DEVICE_LIST = ["cpu", "gpu", "xpu"]
COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"]
DEFAULT_DEVICE = 'cpu'


def delete_optimizer_pass(program, config):
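    """
    Trainer-side pass: remove optimizer and learning-rate ops (and the
    variables only they reference) from the trainer program, since in
    parameter-server mode the update happens on the server side. If the
    origin program carries an lr_scheduler, the learning-rate variable is
    re-created so it can still be fed.
    """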
    def _delete_optimizer_op_and_vars(_program, optimize_ops):
        optimize_vars = []
        optimize_op_role_vars = []
        optimize_need_delete_vars = []

        for op in optimize_ops:
            optimize_vars.extend(op.input_arg_names)
            optimize_op_role_vars.extend(op.attr("op_role_var"))

        optimize_vars = list(set(optimize_vars))
        optimize_op_role_vars = list(set(optimize_op_role_vars))

        for var in optimize_vars:
            if var not in optimize_op_role_vars:
                optimize_need_delete_vars.append(var)
        need_delete_optimize_vars = list(set(optimize_need_delete_vars))

        delete_ops(_program.global_block(), optimize_ops)
        for var in need_delete_optimize_vars:
            if _program.global_block().has_var(var):
                _program.global_block()._remove_var(var)

    def _add_lr_var(main_program, compiled_config):
        # TODO: hard-coded learning rate var name for pe
        lr_var = compiled_config.origin_main_program.global_block().vars[
            "learning_rate_0"
        ]
        main_program.global_block().create_var(
            name=lr_var.name,
            shape=lr_var.shape,
            dtype=lr_var.dtype,
            type=lr_var.type,
            lod_level=lr_var.lod_level,
            persistable=True,
        )

    optimizer_ops = _get_optimize_ops(program)
    lr_ops = _get_lr_ops(program)
    optimizer_ops.extend(lr_ops)
    _delete_optimizer_op_and_vars(program, optimizer_ops)

    if hasattr(config.origin_main_program, 'lr_scheduler'):
        _add_lr_var(program, config)

    return program


def distributed_ops_pass(program, config, use_ps_gpu=False):
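    """
    Rewrite sparse lookup ops for parameter-server training: fuse
    lookup_table / lookup_table_v2 ops that share a table into a single
    distributed_lookup_table op (or pull_gpups_sparse when use_ps_gpu is
    set), and rewrite their gradient ops into distributed_push_sparse.
    """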
    trainer_id = config.get_role_id()
    send_ctx = config.get_the_one_send_context(
        split_dense_table=config.is_heter_ps_mode
    )
    w_2_table_id = {}
    emb_size = {}

    def _get_pull_sparse_ops(_program):
        pull_sparse_ops = {}
        pull_sparse_ids = {}
        push_sparse_ops = {}
        ops = {}
        for op in _program.global_block().ops:
            if (
                op.type in SPARSE_OP_TYPE_DICT.keys()
                and op.attr('remote_prefetch') is True
            ):
                param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
                if config.is_heter_ps_mode:
                    # trick for matchnet; needs to be revised
                    param_name += op.input("Ids")[0][0]
                ops = pull_sparse_ops.get(param_name, [])
                ops.append(op)
                pull_sparse_ops[param_name] = ops
                ids = pull_sparse_ids.get(param_name, [])
                ids.append(op.input("Ids")[0])
                pull_sparse_ids[param_name] = ids
        for op in _program.global_block().ops:
            if op.type in SPARSE_GRAD_OP_TYPE_DICT.keys():
                param_name = op.input(SPARSE_GRAD_OP_TYPE_DICT[op.type])[0]
                if (
                    param_name in pull_sparse_ids
                    and op.input("Ids")[0] in pull_sparse_ids[param_name]
                ):
                    ops = push_sparse_ops.get(param_name, [])
                    ops.append(op)
                    push_sparse_ops[param_name] = ops
        return pull_sparse_ops, push_sparse_ops

    def _pull_sparse_fuse(_program, pull_sparse_ops, use_ps_gpu):
        def dag_check_up_and_reorder(program, inputs, outputs):
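            """
            Re-order ops so that every op producing one of `inputs` appears
            before every op consuming one of `outputs`, making room to insert
            a single fused distributed lookup op in between. Gives up with a
            warning when an op both consumes an output and produces an input.
            """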
            global_block = program.global_block()
            min_output_index = len(global_block.ops)
            max_input_index = -1
            input_indexes = [0] * len(global_block.ops)
            output_indexes = [0] * len(global_block.ops)
            for idx, op in enumerate(global_block.ops):
                for i in range(0, len(op.output_names)):
                    if input_indexes[idx] == 1:
                        break
                    outs = op.output(op.output_names[i])
                    for in_id, in_var in enumerate(inputs):
                        if in_var.name in outs:
                            input_indexes[idx] = 1
                            max_input_index = max(max_input_index, idx)
                            break

                for i in range(0, len(op.input_names)):
                    if output_indexes[idx] == 1:
                        break
                    ins = op.input(op.input_names[i])
                    for out_id, out_var in enumerate(outputs):
                        if out_var.name in ins:
                            output_indexes[idx] = 1
                            min_output_index = min(min_output_index, idx)

            for i in range(len(global_block.ops)):
                if input_indexes[i] == 1 and output_indexes[i] == 1:
                    warnings.warn(
                        "unable to re-arrange dag order to combine distributed embedding ops, "
                        "because an op both consumes the embedding table's output and produces ids used as the same table's input"
                    )
                    return

            if min_output_index < max_input_index:
                move_ops = []
                for i in range(min_output_index + 1, len(input_indexes)):
                    if input_indexes[i] == 1:
                        move_ops.append((global_block.ops[i], i))
                for i, op in enumerate(move_ops):
                    queue = []
                    visited = set()
                    queue.append(op[1])
                    visited.add(op[0])
                    start = 0
                    while start < len(queue):
                        pos = queue[start]
                        op = global_block.ops[pos]
                        op_inputs = []
                        for k in range(0, len(op.input_names)):
                            ins = op.input(op.input_names[k])
                            op_inputs.append(ins)
                        for j in range(pos - 1, min_output_index - 1, -1):
                            op1 = global_block.ops[j]
                            if op1 in visited:
                                continue
                            found = False
                            for k in range(0, len(op1.output_names)):
                                outs = op1.output(op1.output_names[k])
                                for t in range(len(op_inputs)):
                                    for y in op_inputs[t]:
                                        if y in outs:
                                            found = True
                                            break
                                    if found:
                                        break
                                if found:
                                    break
                            if found:
                                if output_indexes[j] == 1:
                                    warnings.warn(
                                        "unable to re-arrange dag order to combine distributed embedding ops"
                                    )
                                    return
                                queue.append(j)
                                visited.add(global_block.ops[j])
                        start = start + 1

                    queue.sort()
                    for index in queue:
                        desc = global_block.desc._insert_op(min_output_index)
                        desc.copy_from(global_block.ops[index].desc)
                        global_block.desc._remove_op(index + 1, index + 2)
                        global_block.ops[index].desc = desc
                        insert_op = global_block.ops.pop(index)
                        input_state = input_indexes.pop(index)
                        output_state = output_indexes.pop(index)
                        global_block.ops.insert(min_output_index, insert_op)
                        input_indexes.insert(min_output_index, input_state)
                        output_indexes.insert(min_output_index, output_state)
                        min_output_index = min_output_index + 1

                assert global_block.desc.op_size() == len(global_block.ops)
                for i in range(len(global_block.ops)):
                    assert global_block.desc.op(i) == global_block.ops[i].desc

        for param, ops in pull_sparse_ops.items():
            all_ops = program.global_block().ops
            op_device = ""
            if config.is_heter_ps_mode:
                op_device = ops[0].attr("op_device")
            inputs = [
                program.global_block().vars[op.input("Ids")[0]] for op in ops
            ]
            w = program.global_block().vars[ops[0].input("W")[0]]
            emb_size[param] = w.shape[1]

            grad_name = config.param_name_to_grad_name[w.name]

            table_id = -1

            for name, ctx in send_ctx.items():
                if grad_name in ctx.origin_varnames():
                    table_id = ctx.table_id()

            if table_id == -1:
                raise ValueError(
                    "can not find suitable sparse table, please check"
                )

            w_2_table_id[param] = table_id
            padding_idx = ops[0].attr("padding_idx")
            is_distributed = ops[0].attr("is_distributed")
            op_type = ops[0].type

            outputs = [
                program.global_block().vars[op.output("Out")[0]] for op in ops
            ]

            dag_check_up_and_reorder(program, inputs, outputs)

            op_idxs = [all_ops.index(op) for op in ops]

            for idx in op_idxs[::-1]:
                program.global_block()._remove_op(idx)

            inputs_idxs = [-1] * len(inputs)
            outputs_idxs = [len(program.global_block().ops) + 1] * len(outputs)

            for idx, op in enumerate(program.global_block().ops):
                for i in range(0, len(op.output_names)):
                    outs = op.output(op.output_names[i])
                    for in_id, in_var in enumerate(inputs):
                        if in_var.name in outs:
                            inputs_idxs[in_id] = max(idx, inputs_idxs[in_id])
                for i in range(0, len(op.input_names)):
                    ins = op.input(op.input_names[i])
                    for out_id, out_var in enumerate(outputs):
                        if out_var.name in ins:
                            outputs_idxs[out_id] = min(
                                idx, outputs_idxs[out_id]
                            )

            if min(outputs_idxs) - max(inputs_idxs) >= 1:
                if max(inputs_idxs) == -1:
                    distributed_idx = min(op_idxs)
                else:
                    distributed_idx = max(inputs_idxs) + 1

                if use_ps_gpu:
                    program.global_block()._insert_op(
                        index=distributed_idx,
                        type="pull_gpups_sparse",
                        inputs={"Ids": inputs, 'W': w},
                        outputs={"Out": outputs},
                        attrs={
                            "size": [w.shape[1] for i in inputs],
                            "is_distributed": True,
                            "is_sparse": True,
                        },
                    )
                else:
                    program.global_block()._insert_op(
                        index=distributed_idx,
                        type="distributed_lookup_table",
                        inputs={"Ids": inputs, 'W': w},
                        outputs={"Outputs": outputs},
                        attrs={
                            "is_distributed": is_distributed,
                            "padding_idx": padding_idx,
                            "table_id": table_id,
                            "lookup_table_version": op_type,
                            "op_device": op_device,
                        },
                    )
            else:
                for i in range(len(inputs_idxs)):
                    distributed_idx = op_idxs[i]

                    program.global_block()._insert_op(
                        index=distributed_idx,
                        type="distributed_lookup_table",
                        inputs={"Ids": [inputs[i]], 'W': w},
                        outputs={"Outputs": [outputs[i]]},
                        attrs={
                            "is_distributed": is_distributed,
                            "padding_idx": padding_idx,
                            "table_id": table_id,
                            "lookup_table_version": op_type,
                            "op_device": op_device,
                        },
                    )

    def _push_sparse_fuse(_program, push_sparse_ops, use_ps_gpu):
        if use_ps_gpu:
            # in ps_gpu_pass
            return
        if len(push_sparse_ops) == 0:
            return
        show = None
        clk = None
        use_entry = False
        # inspect the first push_sparse op for a ShowClickEntry configuration
        op_first = next(iter(push_sparse_ops.values()))[0]
        if op_first.has_attr("entry"):
            entry = op_first.attr("entry")
            entry = entry.split(':')
            if len(entry) == 3 and entry[0] == 'show_click_entry':
                show_var_name = entry[1]
                click_var_name = entry[2]
                if (
                    show_var_name in program.global_block().vars
                    and click_var_name in program.global_block().vars
                ):
                    show = program.global_block().vars[show_var_name]
                    clk = program.global_block().vars[click_var_name]
                    use_entry = True
                else:
                    warnings.warn(
                        'ShowClickEntry configured, but cannot find show/click var, will not use'
                    )

        if not use_entry:
            print('ShowClickEntry not configured, will not use')
            show = program.global_block().create_var(
                name="show",
                dtype=core.VarDesc.VarType.INT64,
                persistable=False,
                stop_gradient=True,
            )
            program.global_block()._insert_op(
                index=0,
                type='fill_constant',
                inputs={},
                outputs={'Out': show},
                attrs={
                    'shape': [1],
                    'dtype': show.dtype,
                    'value': 1,
                    # OP_ROLE_KEY: OpRole.Forward
                },
            )

            clk = program.global_block().create_var(
                name="clk",
                dtype=core.VarDesc.VarType.INT64,
                persistable=False,
                stop_gradient=True,
            )
            program.global_block()._insert_op(
                index=0,
                type='fill_constant',
                inputs={},
                outputs={'Out': clk},
                attrs={
                    'shape': [1],
                    'dtype': clk.dtype,
                    'value': 0,
                    # OP_ROLE_KEY: OpRole.Forward
                },
            )

        for param, ops in push_sparse_ops.items():
            all_ops = program.global_block().ops
            op_idxs = [all_ops.index(op) for op in ops]
            inputs = [
                program.global_block().vars[op.input("Ids")[0]] for op in ops
            ]
            w = program.global_block().vars[ops[0].output("W@GRAD")[0]]
            table_id = w_2_table_id[param]

            padding_idx = ops[0].attr("padding_idx")
            is_distributed = ops[0].attr("is_distributed")
            op_type = ops[0].type
            outputs = [
                program.global_block().vars[op.input("Out@GRAD")[0]]
                for op in ops
            ]

            for idx in op_idxs[::-1]:
                program.global_block()._remove_op(idx)

            #            if use_ps_gpu:
            #                program.global_block().append_op(
            #                    type="push_box_sparse",
            #                    inputs={"Ids": inputs,
            #                            'Out': outputs},
            #                    outputs={"Out": outputs},
            #                    attrs={
            #                        "size": w.shape[1],
            #                        "is_distributed": True,
            #                        "is_sparse": True
            #                    })
            #            else:
            program.global_block().append_op(
                type="distributed_push_sparse",
                inputs={
                    "Ids": inputs,
                    'W': w,
                    "Outputs": outputs,
                    "Shows": show,
                    "Clicks": clk,
                },
                outputs={"Outputs": outputs},
                attrs={
                    "is_distributed": is_distributed,
                    "padding_idx": padding_idx,
                    "table_id": table_id,
                    "size": emb_size[param],
                },
            )

    pull_sparse_ops, push_sparse_ops = _get_pull_sparse_ops(program)
    _pull_sparse_fuse(program, pull_sparse_ops, use_ps_gpu)
    _push_sparse_fuse(program, push_sparse_ops, use_ps_gpu)
    return program


def append_send_ops_pass(program, config):
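    """
    Append a send op for every merged variable/gradient the trainer pushes
    to the parameter server (sparse variables are only sent here in GEO
    mode); in SYNC and HALF_ASYNC modes a trailing send_barrier op is
    appended as well.
    """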
    mode = config.get_distributed_mode()
    trainer_id = config.get_role_id()

    def _append_send_op(union_vars, queue, is_sparse, table_id):
        if queue == STEP_COUNTER:
            send_input_vars = []
        else:
            send_input_vars = [
                program.global_block().vars[union_var]
                for union_var in union_vars
            ]

        dummy_output = []
        if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
            dummy_output = program.global_block().create_var(
                name=framework.generate_control_dev_var_name()
            )

        program.global_block().append_op(
            type="send",
            inputs={"X": send_input_vars},
            outputs={"Out": dummy_output},
            attrs={
                "send_varnames": [queue],
                "is_sparse": is_sparse,
                "table_id": table_id,
                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
            },
        )

        return dummy_output

    def _append_barrier_op(dummys):
        program.global_block().append_op(
            type="send_barrier",
            inputs={"X": dummys},
            outputs={"Out": []},
            attrs={
                "trainer_id": trainer_id,
                "half_async": True,
                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
            },
        )

    dummys = []

    sends = config.get_the_one_trainer_send_context(
        split_dense_table=config.is_heter_ps_mode
    )

    for merged_name, send in sends.items():
        if send.is_sparse() and not config.is_geo_mode():
            continue
        is_sparse = 1 if send.is_sparse() else 0
        is_sparse = 2 if send.is_distributed() else is_sparse
        dummys.append(
            _append_send_op(
                send.origin_varnames(), merged_name, is_sparse, send.table_id()
            )
        )

    if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
        _append_barrier_op(dummys)

    return program


def init_from_server_pass(program, config):
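    """
    Append a fetch_barrier op against the pserver endpoints so that
    non-zero trainers wait for server-side initialization; trainer 0 skips
    it because it already calls barrier at the end of init_worker.
    """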
    # trainer 0 does not need this barrier; it calls barrier at the end of init_worker
    if config.role_maker._is_first_worker():
        return program

    fetch_barrier_out = program.global_block().create_var(
        name=framework.generate_control_dev_var_name()
    )

    program.global_block().append_op(
        type="fetch_barrier",
        inputs={},
        outputs={"Out": fetch_barrier_out},
        attrs={
            "endpoints": config.get_ps_endpoints(),
            "trainer_id": config.get_role_id(),
            RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
        },
    )
    return program


def fake_init_ops_pass(program, config):
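    """
    On the trainer, replace the initializer of every sparse table variable
    with a fake_init op, since the real table lives on the parameter server.
    """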
    origin_program = config.get_origin_main_program()

    def _get_sparse_table_names():
        dist_varnames = get_sparse_tablenames(origin_program, True)
        sparse_varnames = get_sparse_tablenames(origin_program, False)
        return list(set(dist_varnames + sparse_varnames))

    def _fake_init_sparsetable(sparse_table_names):
        # delete the original table init op and add a fake_init instead
        for table_name in sparse_table_names:
            table_var = program.global_block().vars[table_name]
            table_param_init_op = []
            for op in program.global_block().ops:
                if table_name in op.output_arg_names:
                    table_param_init_op.append(op)
            init_op_num = len(table_param_init_op)
            if init_op_num != 1:
                raise ValueError(
                    "table init op num should be 1, now is " + str(init_op_num)
                )
            table_init_op = table_param_init_op[0]
            program.global_block().append_op(
                type="fake_init",
                inputs={},
                outputs={"Out": table_var},
                attrs={"shape": table_init_op.attr('shape')},
            )
            delete_ops(program.global_block(), table_param_init_op)

    sparse_tables = _get_sparse_table_names()
    _fake_init_sparsetable(sparse_tables)

    return program


def ps_gpu_pass(program):
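    """
    GPU parameter-server pass: append the backward (push) ops for
    pull_box_sparse / pull_gpups_sparse, then remove the now-unneeded
    lookup_table_grad ops, their output variables, and the optimizer
    variables that only served the embedding parameters.
    """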
    def _add_push_box_sparse_op(program):
        op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
        backward = core.op_proto_and_checker_maker.OpRole.Backward
        for op in program.global_block().ops:
            if op.type != "pull_box_sparse" and op.type != "pull_gpups_sparse":
                continue
            grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                op.desc, set(), []
            )
            for op_desc in grad_op_desc:
                new_op_desc = program.global_block().desc.append_op()
                new_op_desc.copy_from(op_desc)
                new_op_desc._set_attr(op_role_attr_name, backward)

    def _remove_lookup_table_grad_op_and_var(program):
        lookup_table_grad_var = {}
        remove_op_index = []
        remove_var = []
        for idx, op in list(enumerate(program.global_block().ops)):
            if op.type == "lookup_table_grad":
                for name in op.output("W@GRAD"):
                    lookup_table_grad_var[name] = 1
                    remove_op_index.append(idx)
                    remove_var.append(name)
                for name in op.input("W"):
                    lookup_table_grad_var[name] = 1

        for idx, op in list(enumerate(program.global_block().ops)):
            if op.type == "pull_box_sparse" or op.type == "pull_gpups_sparse":
                continue
            for key_name in op.input_names:
                for var in op.input(key_name):
                    if var in lookup_table_grad_var:
                        remove_op_index.append(idx)
                        break

        remove_op_index = list(set(remove_op_index))
        remove_op_index.sort(reverse=True)
        for idx in remove_op_index:
            program.global_block()._remove_op(idx)
        for name in remove_var:
            program.global_block()._remove_var(name)

    def _remove_optimizer_var(program):
        embedding_w = {}
        for idx, op in list(enumerate(program.global_block().ops)):
            if op.type == "lookup_table_grad":
                for name in op.input("W"):
                    embedding_w[name] = 1

        optimize_vars = []
        optimize_op_role_vars = []
        optimize_need_delete_vars = []
        for op in _get_optimize_ops(program):
            for name in op.input("Param"):
                if name in embedding_w:
                    optimize_op_role_vars.extend(op.attr("op_role_var"))
                    for key_name in op.input_names:
                        if key_name == "LearningRate":
                            continue
                        for var in op.input(key_name):
                            optimize_vars.append(var)

        optimize_vars = list(set(optimize_vars))
        optimize_op_role_vars = list(set(optimize_op_role_vars))

        for var in optimize_vars:
            if var not in optimize_op_role_vars:
                optimize_need_delete_vars.append(var)
        need_delete_optimize_vars = list(set(optimize_need_delete_vars))

        for name in need_delete_optimize_vars:
            if program.global_block().has_var(name):
                program.global_block()._remove_var(name)

    _add_push_box_sparse_op(program)
    _remove_optimizer_var(program)
    _remove_lookup_table_grad_op_and_var(program)
    return program


def delete_extra_optimizes_pass(program, config):
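    """
    Remove from the trainer program the initializer ops and variables of
    optimizer-only variables (those referenced by optimizer ops but not
    listed in op_role_var).
    """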
    optimize_vars = []
    optimize_op_role_vars = []
    optimize_need_delete_vars = []

    origin_program = config.get_origin_main_program()
    for op in _get_optimize_ops(origin_program):
        optimize_vars.extend(op.input_arg_names)
        optimize_op_role_vars.extend(op.attr("op_role_var"))

    optimize_vars = list(set(optimize_vars))
    optimize_op_role_vars = list(set(optimize_op_role_vars))
    for var in optimize_vars:
        if var not in optimize_op_role_vars:
            optimize_need_delete_vars.append(var)
    need_delete_optimize_vars = list(set(optimize_need_delete_vars))

    init_ops = []
    for var in need_delete_optimize_vars:
        param_init_op = []
        for op in program.global_block().ops:
            if var in op.output_arg_names:
                param_init_op.append(op)
        init_ops.extend(param_init_op)
    delete_ops(program.global_block(), init_ops)

    for var in need_delete_optimize_vars:
        if program.global_block().has_var(var):
            program.global_block()._remove_var(var)

    return program


def find_heter_ops(program, default_device="cpu"):
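    """
    Split the ops of `program` into device blocks according to their
    op_device attribute. Returns (origin_program, heter_ops, default_ops,
    program_block_ops): heter_ops and default_ops map
    device -> {block_index: [ops]}, and program_block_ops lists the op
    blocks in execution order.
    """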
    if default_device not in DEVICE_LIST:
        raise ValueError(
            f"Given device {default_device} is not in device list {DEVICE_LIST}"
        )

    def _is_heter_op(op, current_heter_device, default_device="cpu"):
        heter_devices = list(DEVICE_LIST)
        heter_devices.remove(default_device)
        op_device = op.attr("op_device")
        op_type = op.type
        if op_device in heter_devices:
            return True
        elif (
            op_type in COMMUNICATE_OPS_TYPE
            and current_heter_device != default_device
        ):
            # for distributed communication ops: send, recv, barrier, etc.
            # TODO: this method needs updating
            # op._set_attr('op_device', current_heter_device)
            return True
        elif op_device is None or op_device == default_device:
            op._set_attr('op_device', default_device)
            return False
        return False

    def _is_same_device(op, pre_device, default_device="cpu"):
        op_device = op.attr("op_device")
        if op_device == pre_device:
            return True
        if pre_device == default_device:
            return True
        return False

    def _append_heter_op(op, current_heter_block_ops, heter_ops):
        op_device = op.attr("op_device")
        if op_device not in heter_ops:
            heter_ops[op_device] = {}
        current_heter_block_ops.append(op)

    origin_program = program.clone()
    block = program.global_block()
    # relocate sum ops to fix a bug when unioning forward and backward ops
    var2idx = {}
    op_list = list(block.ops)
    op_size = len(op_list)

    for i in range(op_size - 1, -1, -1):
        op_list = list(block.ops)
        op = op_list[i]
        if "_grad" in op.type:
            forward_op_type = op.type.split("_grad")[0]
            if (
                forward_op_type in SPARSE_OP_TYPE_DICT.keys()
                and op.attr('remote_prefetch') is True
            ):
                param_name = op.input(SPARSE_OP_TYPE_DICT[forward_op_type])[0]
                if param_name in var2idx:
                    # re-insert the recorded sum op here, then remove it from its original place and from var2idx
                    op_list = list(block.ops)
                    sum_op = op_list[var2idx[param_name]]
                    sum_op_inputs = {
                        sum_op.input_names[0]: [
                            block.vars[input]
                            for input in sum_op.input_arg_names
                        ]
                    }
                    sum_op_outputs = {
                        sum_op.output_names[0]: [
                            block.vars[output]
                            for output in sum_op.output_arg_names
                        ]
                    }
                    block._insert_op(
                        index=i + 1,
                        type=sum_op.type,
                        inputs=sum_op_inputs,
                        outputs=sum_op_outputs,
                        attrs=sum_op.all_attrs(),
                    )
                    block._remove_op(var2idx[param_name] + 1)
                    var2idx.pop(param_name)
                    for var_ in var2idx:
                        var2idx[var_] += 1
            elif forward_op_type == "elementwise_mul":
                """
                get output varname of pre op

                """
                output_vars_no_grad = []
                for key in op.output_names:
                    for varname in op.output(key):
                        if varname == "@EMPTY@":
                            continue
                        if "lod_tensor_blocking_queue" in varname:
                            continue
                        output_vars_no_grad.append(varname.split("@GRAD")[0])
                for no_grad_var in output_vars_no_grad:
                    if no_grad_var in var2idx:
                        """
                        insert sum op & remove sum op from var2idx and origin place

                        """
                        op_list = list(block.ops)
                        sum_op = op_list[var2idx[no_grad_var]]
                        sum_op_inputs = {
                            sum_op.input_names[0]: [
                                block.vars[input]
                                for input in sum_op.input_arg_names
                            ]
                        }
                        sum_op_outputs = {
                            sum_op.output_names[0]: [
                                block.vars[output]
                                for output in sum_op.output_arg_names
                            ]
                        }
                        block._insert_op(
                            index=i + 1,
                            type=sum_op.type,
                            inputs=sum_op_inputs,
                            outputs=sum_op_outputs,
                            attrs=sum_op.all_attrs(),
                        )
                        block._remove_op(var2idx[no_grad_var] + 1)
                        var2idx.pop(no_grad_var)
                        for var_ in var2idx:
                            var2idx[var_] += 1
        else:
            if op.type == "sum":
                var = op.output("Out")[0]
                if "@GRAD" in var:
                    origin_var = var.split("@GRAD")[0]
                    pre_op = op_list[i - 1]
                    if "_grad" in pre_op.type:
                        forward_op_type = pre_op.type.split("_grad")[0]
                        if (
                            forward_op_type in SPARSE_OP_TYPE_DICT.keys()
                            and pre_op.attr('remote_prefetch') is True
                        ):
                            param_name = pre_op.input(
                                SPARSE_OP_TYPE_DICT[forward_op_type]
                            )[0]
                            if param_name == origin_var and op.attr(
                                "op_device"
                            ) == pre_op.attr("op_device"):
                                continue
                            else:
                                var2idx[origin_var] = i
                        elif forward_op_type == "elementwise_mul":
                            output_vars = []
                            for key in pre_op.output_names:
                                for varname in pre_op.output(key):
                                    if varname == "@EMPTY@":
                                        continue
                                    if "lod_tensor_blocking_queue" in varname:
                                        continue
                                    output_vars.append(varname)
                            input_vars = []
                            for key in op.input_names:
                                for varname in op.input(key):
                                    if varname == "@EMPTY@":
                                        continue
                                    if "lod_tensor_blocking_queue" in varname:
                                        continue
                                    input_vars.append(varname)
                            is_match = False
                            for varname in output_vars:
                                if varname in input_vars:
                                    is_match = True
                                    break
                            if is_match:
                                continue
                            else:
                                var2idx[origin_var] = i
                    else:
                        var2idx[origin_var] = i

    origin_program = program.clone()
    block = program.global_block()

    program_block_ops = []
    default_ops = {default_device: {}}
    heter_ops = {}
    block_index = 0

    current_heter_block_ops = []
    current_default_block_ops = []
    current_heter_device = default_device
    is_heter = False
    for op in block.ops:
        if _is_heter_op(op, current_heter_device, default_device):
            # for gpu/xpu-op
            is_heter = True

            # for cpu-op block append
            if len(current_default_block_ops) > 1:
                default_ops[default_device][block_index] = (
                    current_default_block_ops
                )
                program_block_ops.append(current_default_block_ops)
                current_default_block_ops = []
                block_index += 1

            if _is_same_device(op, current_heter_device, default_device):
                # for gpu-op, gpu-op -> gpu-op,...
                current_heter_device = op.attr("op_device")
                _append_heter_op(op, current_heter_block_ops, heter_ops)
            else:
                # for gpu-op -> xpu-op, ...
                op_device = current_heter_block_ops[0].attr("op_device")
                heter_ops[op_device][block_index] = current_heter_block_ops
                program_block_ops.append(current_heter_block_ops)
                block_index += 1
                current_heter_block_ops = []
                current_heter_device = op.attr("op_device")
                _append_heter_op(op, current_heter_block_ops, heter_ops)

        elif is_heter:
            # for gpu/xpu-op -> cpu-op
            op_device = current_heter_block_ops[0].attr("op_device")
            heter_ops[op_device][block_index] = current_heter_block_ops
            program_block_ops.append(current_heter_block_ops)
            block_index += 1
            current_heter_block_ops = []
            current_heter_device = default_device
            is_heter = False
            current_default_block_ops.append(op)
        else:
            # for cpu-op
            current_default_block_ops.append(op)

    if current_default_block_ops != []:
        default_ops[default_device][block_index] = current_default_block_ops
        program_block_ops.append(current_default_block_ops)

    if current_heter_block_ops != []:
        op_device = current_heter_block_ops[0].attr("op_device")
        heter_ops[op_device][block_index] = current_heter_block_ops
        program_block_ops.append(current_heter_block_ops)

    if len(heter_ops) == 0:
        warnings.warn(
            "No heterogeneous OP was found in your program, "
            "please use paddle.static.device_guard() to run OPs on different devices."
        )

    total_heter_ops = 0
    heter_blocks = 0
    for device in heter_ops.keys():
        heter_block_dict = heter_ops[device]
        heter_blocks += len(heter_block_dict)
        for _, heter_block in heter_block_dict.items():
            total_heter_ops += len(heter_block)
    print(
        f"There are {len(block.ops)} OPs in your main_program, "
        f"{total_heter_ops} of which are heter-OPs spread across {heter_blocks} heter-blocks."
    )

    return origin_program, heter_ops, default_ops, program_block_ops


def create_heter_program(
    program,
    config,
    heter_program,
    program_block_ops_list,
    heter_ops,
    block_var_detail,
    current_device,
    stage_id,
):
    # This function mainly includes the following contents:
    # 1. For every heter block:
    #     a) copy heter device op from origin program
    #     b) create variables which belong to heter op:
    #         -> if variable is persistable, clone it in global_scope
    #         -> if variable is temp, create it in heter block
    #     c) create communicate related op as follow:
    #         joint_var.0_1 -> slice -> reshape -> origin_var
    #         origin_var -> origin_program
    #         reshape -> concat -> joint_var.1_2
    #     d) copy send op from origin program for each var@grad located in the current heter block
    #     e) re-check every op in the current block in case its device is not the current heter device
    # 2. Create send op for step counter in last heter-block
    # 3. Create Listen&Serv OP and Send&Recv OP for distributed training
    # 4. Update CompileTimeStrategy for heter_program

    optimizer_block = []
    grad_to_block_id = []
    send_grad_var_list = []

    pre_block_idx = heter_program.num_blocks - 1
    stage_id = int(stage_id)
    print("stage id", stage_id)
    heter_block_ops_forward = program_block_ops_list[stage_id - 1]["forward"]

    heter_block_ops_backward = program_block_ops_list[stage_id - 1]["backward"]

    heter_block = heter_program._create_block(pre_block_idx)
    optimizer_block.append(heter_block)
    for _, op in enumerate(heter_block_ops_forward):
        block_append_op(heter_program, program, heter_block, op)

    entrance_vars = block_var_detail[stage_id - 1]["forward"]["entrance"]
    add_vars_by_var_list(entrance_vars, program, heter_program, heter_block)
    exit_vars = block_var_detail[stage_id - 1]["forward"]["exit"]
    add_vars_by_var_list(exit_vars, program, heter_program, heter_block)

    first_op_index_fp = len(heter_block.ops)

    if stage_id < len(program_block_ops_list):
        heter_block_bp = heter_program._create_block(pre_block_idx)
        optimizer_block.append(heter_block_bp)

        for _, op in enumerate(heter_block_ops_backward):
            block_append_op(heter_program, program, heter_block_bp, op)

        bp_entrance_vars = block_var_detail[stage_id - 1]["backward"][
            "entrance"
        ]
        add_vars_by_var_list(
            bp_entrance_vars, program, heter_program, heter_block_bp
        )
        bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"]
        add_vars_by_var_list(
            bp_exit_vars, program, heter_program, heter_block_bp
        )
        backward_comm_info = get_communicate_var_info(
            program, stage_id, bp_entrance_vars, type="backward"
        )

        grad_to_block_id.append(
            backward_comm_info["block_input_var_name"]
            + ":"
            + str(heter_block_bp.idx)
        )

    else:
        for _, op in enumerate(heter_block_ops_backward):
            block_append_op(heter_program, program, heter_block, op)

        bp_entrance_vars = block_var_detail[stage_id - 1]["backward"][
            "entrance"
        ]
        add_vars_by_var_list(
            bp_entrance_vars, program, heter_program, heter_block
        )
        bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"]
        add_vars_by_var_list(bp_exit_vars, program, heter_program, heter_block)

        heter_block_bp = heter_block

    forward_comm_info = get_communicate_var_info(
        program, stage_id, entrance_vars, type="forward"
    )

    grad_to_block_id.append(
        forward_comm_info["block_input_var_name"] + ":" + str(heter_block.idx)
    )

    first_op_index_bp = len(heter_block_bp.ops)

    if stage_id <= len(block_var_detail) - 1:
        static_var = insert_communicate_op(
            program,
            config,
            heter_block,
            stage_id,
            first_op_index_fp,
            block_var_detail,
            current_device,
        )
    static_var_bp = insert_communicate_op(
        program,
        config,
        heter_block_bp,
        stage_id,
        first_op_index_bp,
        block_var_detail,
        current_device,
        False,
    )

    # add send op
    send_grad_var_list = add_heter_send_op(
        program, heter_program, heter_block_bp, block_var_detail[stage_id - 1]
    )

    # ---------------
    # add step counter
    send_input_vars = []
    dummy_output = []
    pserver_endpoints = config.get_ps_endpoints()

    # optimizer_block[-1].append_op(
    #     type="send",
    #     inputs={"X": send_input_vars},
    #     outputs={"Out": dummy_output},
    #     attrs={
    #         "send_varnames": [STEP_COUNTER],
    #         "merge_add": True,
    #         "use_send_handler": False,
    #         "endpoints": pserver_endpoints
    #     })

    # add info in listen&serv
    attrs = {
        # "mode": "sync",
        # "trainers": config.get_trainers(),
        # "trainer_id": config.get_role_id() + config.get_trainers(),
        "message_to_block_id": grad_to_block_id,
        "optimize_blocks": optimizer_block,
        # runtime attribute
        "endpoint": config.get_heter_worker_endpoint(),
        "fanin": len(config.get_previous_stage_trainers()),
        "pserver_id": config.get_role_id(),
        "distributed_mode": config.get_distributed_mode(),
        "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
        RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
    }
    # append the listen_and_serv op
    heter_program.global_block().append_op(
        type="heter_listen_and_serv", inputs={'X': []}, outputs={}, attrs=attrs
    )
    check_heter_compile_time_strategy(program, config, send_grad_var_list)


def check_heter_compile_time_strategy(program, config, send_grad_var_list):
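    """
    Drop merged var/grad pairs from the compile-time config whose gradients
    are no longer sent from this heter program.
    """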
    origin_grad_var_list = []
    for _, var_grad in config.merged_variables_pairs:
        origin_grad_var_list.append(var_grad.merged_var.name)

    origin_grad_var_list = list(set(origin_grad_var_list))
    send_grad_var_list = list(set(send_grad_var_list))
    useless_grad_var_list = list(
        set(origin_grad_var_list) - set(send_grad_var_list)
    )

    for useless_grad_var in useless_grad_var_list:
        config.remove_var_pair_by_grad(useless_grad_var)


def create_trainer_program(
    program, origin_program, config, program_block_ops_list, block_var_detail
):
    # This function mainly includes the following contents:
    # 1. For every heter block in origin program
    #     a) delete heter op and related variables
    #     b) add send&recv op
    #     c) add communicate ops as follows:
    #         origin_var -> reshape -> concat -> joint_var.0_1
    #         send&recv op(send joint_var.0_1; recv joint_var.1_2)
    #         joint_var.1_2 -> slice -> reshape -> origin_var
    #     d) remove send op which related var@grad is not in trainer program
    # 2. check every op's device
    static_var = []
    for heter_block_index in range(1, len(program_block_ops_list)):
        ops_list = (
            program_block_ops_list[heter_block_index]["forward"]
            + program_block_ops_list[heter_block_index]["backward"]
        )
        static_var += replace_ops_by_communicate_op(
            program, config, heter_block_index, ops_list, block_var_detail
        )
        remove_trainer_send_op(
            program, config, heter_block_index, block_var_detail
        )

    optimizer_block = []
    grad_to_block_id = []

    bp_ops_list = program_block_ops_list[0]["backward"]
    delete_same_ops(program.global_block(), bp_ops_list)
    delete_trainer_useless_var(config, program, static_var)
    backward_block = create_backward_block(
        program, origin_program, config, bp_ops_list, block_var_detail
    )

    bp_entrance_vars = block_var_detail[0]["backward"]["entrance"]
    backward_comm_info = get_communicate_var_info(
        origin_program, 1, bp_entrance_vars, type="backward"
    )

    grad_to_block_id.append(
        backward_comm_info["block_input_var_name"]
        + ":"
        + str(backward_block.idx)
    )
    optimizer_block.append(backward_block)

    attrs = {
        # "mode": "sync",
        # "trainers": config.get_trainers(),
        # "trainer_id": config.get_role_id(),
        "message_to_block_id": grad_to_block_id,
        "optimize_blocks": optimizer_block,
        # runtime attribute
        "endpoint": config.get_trainer_endpoint(),  # get trainer endpoint
        "fanin": 0,  # get heter worker
        "pserver_id": config.get_role_id(),
        "distributed_mode": config.get_distributed_mode(),
        "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
        RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
    }
    # append the listen_and_serv op
    program.global_block()._insert_op(
        index=0,
        type="heter_listen_and_serv",
        inputs={'X': []},
        outputs={},
        attrs=attrs,
    )

    # TODO add check for bp block
    check_op_device(program.global_block(), DEFAULT_DEVICE)


def insert_communicate_op(
    origin_program,
    config,
    heter_block,
    stage_id,
    first_op_index,
    block_var_detail,
    device,
    is_forward=True,
):
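    """
    Insert a send_and_recv op at `first_op_index` of `heter_block` to
    exchange the stage's entrance variables (plus the microbatch id) with
    the previous / next pipeline stage. Returns the entrance variable names.
    """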
    if is_forward:
        next_heter_worker_endpoints = config.get_next_stage_trainers()
        previous_heter_worker_endpoints = config.get_previous_stage_trainers()
        entrance_var = block_var_detail[stage_id]["forward"]["entrance"]
        comm_info = get_communicate_var_info(
            origin_program, stage_id + 1, entrance_var
        )

    else:
        next_heter_worker_endpoints = config.get_next_stage_trainers()
        # if next_heter_worker_endpoints == "":
        #    next_heter_worker_endpoints = []
        previous_heter_worker_endpoints = config.get_previous_stage_trainers()
        entrance_var = block_var_detail[stage_id - 1]["backward"]["exit"]
        comm_info = get_communicate_var_info(
            origin_program, stage_id - 1, entrance_var, "backward"
        )

    heter_block._insert_op(
        index=first_op_index,
        type="send_and_recv",
        inputs={"X": heter_block.vars[entrance_var[0]]},
        outputs={"Out": []},
        attrs={
            "mode": "forward" if is_forward else "backward",
            "send_var_name": [*entrance_var, "microbatch_id"],
            "recv_var_name": [],
            "message_name": comm_info["block_input_var_name"],
            "next_endpoints": next_heter_worker_endpoints,
            "previous_endpoints": previous_heter_worker_endpoints,
            "trainer_id": config.get_role_id(),
            "op_device": device,
            RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
        },
    )

    return entrance_var


def create_backward_block(
    program, origin_program, config, bp_ops_list, block_var_detail
):
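    """
    Create a new block in `program` holding the first stage's backward ops
    (skipping send ops whose send_varnames are no longer present), together
    with that stage's entrance and exit variables.
    """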
    pre_block_idx = program.num_blocks - 1
    heter_block = program._create_block(pre_block_idx)

    for _, op in enumerate(bp_ops_list):
        if op.type == "send":
            send_varnames = op.attr('send_varnames')
            is_skip = False
            for varname in send_varnames:
                if (
                    varname not in program.global_block().vars
                    and varname not in heter_block.vars
                ):
                    is_skip = True
                    break
            if is_skip is True:
                continue
        block_append_op(program, origin_program, heter_block, op)

    entrance_vars = block_var_detail[0]["backward"]["entrance"]
    add_vars_by_var_list(entrance_vars, origin_program, program, heter_block)
    exit_vars = block_var_detail[0]["backward"]["exit"]
    add_vars_by_var_list(exit_vars, origin_program, program, heter_block)
    return heter_block


def replace_ops_by_communicate_op(
    program, config, heter_block_index, ops_list, block_var_detail
):
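    """
    On the trainer, delete the ops that were moved to a heter worker and,
    for the first heter block, insert a send_and_recv op in their place to
    forward the entrance variables to the next stage. Returns the entrance
    variable names (empty for other blocks).
    """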
    all_op = program.global_block().ops
    start_op = ops_list[0]
    first_op_idx = -1
    for op in all_op:
        if is_same_op(op, start_op):
            first_op_idx = all_op.index(op)
            break
    assert first_op_idx != -1
    delete_same_ops(program.global_block(), ops_list)

    entrance_var = []

    if heter_block_index == 1:
        mode = config.get_distributed_mode()
        next_heter_worker_endpoints = config.get_next_stage_trainers()

        entrance_var = block_var_detail[heter_block_index]["forward"][
            "entrance"
        ]

        comm_info = get_communicate_var_info(
            program, heter_block_index + 1, entrance_var
        )
        program.global_block()._insert_op(
            index=first_op_idx,
            type="send_and_recv",
            inputs={"X": program.global_block().vars[entrance_var[0]]},
            outputs={"Out": []},
            attrs={
                "mode": "forward",
                "send_var_name": [*entrance_var, "microbatch_id"],
                "recv_var_name": [],
                "message_name": comm_info["block_input_var_name"],
                "next_endpoints": next_heter_worker_endpoints,
                "previous_endpoints": [],
                "trainer_id": config.get_role_id(),
                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
            },
        )

    return entrance_var


def remove_trainer_send_op(
    program, config, heter_block_index, block_var_detail
):
    # if trainer do FF->BP->SEND, it has follow vars: var, var@GRAD
    # if trainer only do SEND, it has one var: var@GRAD
    # Delete Send op ,if trainer doesn't has pair var (var<->var@GRAD)
    persistables = (
        block_var_detail[heter_block_index]["forward"]["persistables"]
        + block_var_detail[heter_block_index]["backward"]["persistables"]
    )
    need_remove_send_op = []
    need_remove_grad_var = []
    for op in find_send_op(program):
        input_list, _ = find_op_input_output(
            program, program.global_block(), op
        )
        for var_name in input_list:
            origin_var_name = var_name.split("@GRAD")[0]
            if origin_var_name in persistables:
                need_remove_send_op.append(op)
                need_remove_grad_var.append(var_name)
    need_remove_send_op = list(set(need_remove_send_op))
    delete_ops(program.global_block(), need_remove_send_op)
    for grad_var_name in need_remove_grad_var:
        config.remove_var_pair_by_grad(grad_var_name)


def add_heter_send_op(program, heter_program, block, block_var_detail):
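    """
    Re-create, inside `block` of the heter program, the send ops for every
    var@GRAD persistable of this stage, grouping variables by table id.
    Returns the list of gradient variable names that are sent.
    """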
    def _get_send_op_dict():
        send_op_dict = {}
        send_op_list = find_send_op(program)
        for op in send_op_list:
            input_list, _ = find_op_input_output(
                program, program.global_block(), op
            )
            for var in input_list:
                send_op_dict[var] = op
        return send_op_dict

    # send op layout:
    #   inputs:  {'X': []}
    #   outputs: {'Out': dummy_output}
    #   attrs:   {'send_varnames': [], 'is_sparse': int, 'table_id': int}
    send_grad_var_list = []
    send_op_dict = _get_send_op_dict()
    table_dict = {}
    for persistable_var in block_var_detail["backward"]["persistables"]:
        # only handle gradient vars whose names match the var@GRAD pattern
        if "@GRAD" not in persistable_var:
            continue
        if "GRAD" != persistable_var.split("@")[-1]:
            continue
        if persistable_var not in send_op_dict:
            continue
        send_op = send_op_dict[persistable_var]
        is_sparse = send_op.attr('is_sparse')
        table_id = send_op.attr('table_id')
        send_varnames = send_op.attr('send_varnames')
        send_grad_var_list.append(persistable_var)
        if table_id not in table_dict:
            table_dict[table_id] = {}
            table_dict[table_id]['var_list'] = []
            table_dict[table_id]['is_sparse'] = is_sparse
            table_dict[table_id]['send_varnames'] = send_varnames
        table_dict[table_id]['var_list'].append(persistable_var)

    for table_id in table_dict:
        dummy_output = block.create_var(
            name=framework.generate_control_dev_var_name()
        )
        send_input_vars = [
            block.vars[union_var]
            for union_var in table_dict[table_id]['var_list']
        ]
        block.append_op(
            type="send",
            inputs={"X": send_input_vars},
            outputs={"Out": dummy_output},
            attrs={
                "send_varnames": table_dict[table_id]['send_varnames'],
                "is_sparse": is_sparse,
                "table_id": table_id,
                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
            },
        )

    return send_grad_var_list


def find_send_op(program):
    send_op_list = []
    for op in program.global_block().ops:
        if op.type == "send":
            send_op_list.append(op)
    return send_op_list


def get_communicate_var_info(
    program, block_index, entrance_var_list, type="forward"
):
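    """
    Build the communication metadata for one block boundary: the joint
    input variable name (forward_joint_*@Heter or backward_joint_*@Heter)
    plus the flattened reshape dims and names for each entrance variable.
    """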
    input_var_reshape_dim = []
    input_var_reshape_name = []

    if type == "forward":
        block_input_var_name = (
            f"forward_joint_{block_index - 1}_{block_index}@Heter"
        )
    else:
        block_input_var_name = (
            f"backward_joint_{block_index + 1}_{block_index}@Heter"
        )

    entrance_var_list.sort()
    # input
    # Heter_SERVER_BLOCK_index@JOINT_VAR -> slice -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> reshape -> var
    for name in entrance_var_list:
        var = program.global_block().vars[name]
        shape = var.shape
        # if len(shape) < 2 or shape[0] != -1:
        #     raise ValueError(
        #         "Variable {} not support heter training. its shape is {}".
        #         format(name, shape))
        recv_var_dim = -1 * reduce(lambda x, y: x * y, shape, 1)
        input_var_reshape_dim.append(recv_var_dim)
        input_var_reshape_name.append(f"{name}.input_reshape@Heter")

    # output
    # var -> reshape -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> concat -> Heter_SERVER_BLOCK_index@JOINT_VAR
    # for var_name in exit_var_list:
    #    var = program.global_block().vars[var_name]
    #    shape = var.shape
    #    # if len(shape) < 2 or shape[0] != -1:
    #    #     raise ValueError(
    #    #         "Variable {} not support heter training. its shape is {}".
    #    #         format(var_name, shape))
    #    send_reshape_dim = -1 * reduce(lambda x, y: x * y, shape, 1)
    #    output_var_reshape_dim.append(send_reshape_dim)
    #    output_var_reshape_name.append("{}.output_reshape@Heter".format(
    #        var_name))

    info = {
        "input_var_reshape_dim": input_var_reshape_dim,
        "input_var_reshape_name": input_var_reshape_name,
        "block_input_var_name": block_input_var_name,
        #    "output_var_reshape_dim": output_var_reshape_dim,
        #    "output_var_reshape_name": output_var_reshape_name,
        #    "block_output_var_name": block_output_var_name
    }

    return info


def union_forward_gradient_op(program_block_ops_list):
    """
    before analyzing the input & output of each block in program_block_list, we should
    union the forward op and corresponding gradient op to eliminate the unnecessary variable
    transmit
    """
    """
    fix for 2emb model, re-place sum op

    """
    block_length = len(program_block_ops_list)
    '''
    ## get the final part
    final_part_idx = -1
    for i in range(block_length):
        op_list = program_block_ops_list[i]
        for op in op_list:
           if "_grad" in op.type:
              final_part_idx = i
              break
        if final_part_idx != -1:
            break

    ## eliminate wrong partition because of sum op
    ## lookup_table_v2_grad
    ## every lookup_table_v2_grad op block should follow a sum op
    var2idx  = {}

    for i in range(final_part_idx, block_length):
        op_list = program_block_ops_list[i]
        for j in range(len(op_list) - 1, -1, -1):
            op = op_list[j]
            #if op.type == "lookup_table_v2_grad":
            #   if j < len(op_list) - 1):
            #   else:
            #      ## get var and record place
            if "_grad" in op.type:
                forward_op_type = op.type.split("_grad")[0]
                if forward_op_type in SPARSE_OP_TYPE_DICT.keys() \
                    and op.attr('remote_prefetch') is True:
                    param_name = op.input(SPARSE_OP_TYPE_DICT[forward_op_type])[0]

                    var2idx[param_name] = [i, j]

    '''

    union_program_block_ops_list = []
    assert block_length % 2 != 0, (
        "the length of program_block_ops_list should be odd"
    )
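    # Blocks are paired symmetrically around the middle one, e.g. with 5 blocks
    # the pairs are (0, 4) and (1, 3); the middle block (index 2) is split below
    # into its forward ops and its backward (grad / sum) ops.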
    for i in range(0, block_length // 2):
        block_op_list = {"forward": program_block_ops_list[i]}
        block_op_list.update(
            {"backward": program_block_ops_list[block_length - 1 - i]}
        )
        union_program_block_ops_list.append(block_op_list)

    block_op_list = {"forward": [], "backward": []}
    for op in program_block_ops_list[block_length // 2]:
        if "_grad" not in op.type and not (op.type == "sum"):
            block_op_list["forward"].append(op)
        else:
            block_op_list["backward"].append(op)
    union_program_block_ops_list.append(block_op_list)
    return union_program_block_ops_list


def find_block_joints(program, program_block_ops_list, heter_ops):
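    """
    For every heter block, work out its entrance, exit, private and persistable
    vars, reconcile the joints between adjacent blocks, and drop exits that no
    following block consumes.
    """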
    block_var_detail = find_entrance_exit_private(
        program, program_block_ops_list
    )
    block_var_detail = entrance_exit_check(
        program, program_block_ops_list, block_var_detail, heter_ops
    )
    block_var_detail = delete_block_useless_exit(
        program, program_block_ops_list, block_var_detail
    )

    return block_var_detail


def find_entrance_exit_private(program, program_block_ops_list):
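    """
    Classify the vars of each block: vars appearing as both input and output
    are private, remaining inputs are entrances, remaining outputs are exits;
    persistable vars are screened out and recorded separately. The forward and
    backward op lists are handled alike.
    """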
    block_var_detail = []
    persistables = []
    for index, block_op_list in enumerate(program_block_ops_list):
        # forward
        block_input, block_output = find_ops_list_input_output(
            program, block_op_list["forward"]
        )
        persistables = screen_persistables(
            program, block_input
        ) + screen_persistables(program, block_output)
        # find entrance & exit
        block_private_vars = list(set(block_input) & set(block_output))
        block_entrance = list(set(block_input) - set(block_private_vars))
        block_exit = list(set(block_output) - set(block_private_vars))
        detail = {
            "forward": {
                "entrance": block_entrance,
                "exit": block_exit,
                "private": block_private_vars,
                "persistables": persistables,
            }
        }

        # backward
        bp_block_input, bp_block_output = find_ops_list_input_output(
            program, block_op_list["backward"]
        )
        bp_persistables = screen_persistables(
            program, bp_block_input
        ) + screen_persistables(program, bp_block_output)
        # find entrance & exit
        bp_block_private_vars = list(set(bp_block_input) & set(bp_block_output))
        bp_block_entrance = list(
            set(bp_block_input) - set(bp_block_private_vars)
        )
        bp_block_exit = list(set(bp_block_output) - set(bp_block_private_vars))
        detail.update(
            {
                "backward": {
                    "entrance": bp_block_entrance,
                    "exit": bp_block_exit,
                    "private": bp_block_private_vars,
                    "persistables": bp_persistables,
                }
            }
        )
        block_var_detail.append(detail)
    return block_var_detail


def entrance_exit_check(
    program, program_block_ops_list, block_var_detail, heter_ops
):
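    """
    Align each block's exit list with the entrance list of the block that
    consumes it, walking the forward stages back to front and the backward
    stages front to back; missing vars are appended to the producing block's
    entrance/exit lists.
    """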
    for index in range(len(block_var_detail) - 1, -1, -1):
        if index - 1 < 0:
            break
        previous_block_exit = block_var_detail[index - 1]["forward"]["exit"]
        previous_block_exit.sort()
        current_block_entrance = block_var_detail[index]["forward"]["entrance"]

        backward_entrance = block_var_detail[index]["backward"]["entrance"]

        forward_all = (
            block_var_detail[index]["forward"]["entrance"]
            + block_var_detail[index]["forward"]["private"]
            + block_var_detail[index]["forward"]["exit"]
        )

        for var in backward_entrance:
            if "@GRAD" not in var and var not in forward_all:
                current_block_entrance.append(var)

        current_block_entrance.sort()

        if previous_block_exit == current_block_entrance:
            continue
        exist_vars = list(
            set(previous_block_exit) & set(current_block_entrance)
        )
        need_add_vars = list(set(current_block_entrance) - set(exist_vars))
        # vars in different stages should not be ignored, since they are not placed in the same program & device
        # need_add_vars = find_need_var_from_previous_block(
        #    need_add_vars, block_var_detail, index, heter_ops)

        previous_block_private = block_var_detail[index - 1]["forward"][
            "private"
        ]
        previous_block_entrance = block_var_detail[index - 1]["forward"][
            "entrance"
        ]
        for var in need_add_vars:
            if (
                var not in previous_block_private
                and var not in previous_block_entrance
            ):
                previous_block_entrance.append(var)
            previous_block_exit.append(var)
            if var not in current_block_entrance:
                current_block_entrance.append(var)

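    # Backward stages execute in reverse block order, so the block at
    # index + 1 acts as the "previous" stage here.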
    for index in range(0, len(block_var_detail) - 1, 1):
        previous_block_exit = block_var_detail[index + 1]["backward"]["exit"]
        previous_block_exit.sort()
        current_block_entrance = block_var_detail[index]["backward"]["entrance"]

        current_block_entrance.sort()

        if previous_block_exit == current_block_entrance:
            continue
        exist_vars = list(
            set(previous_block_exit) & set(current_block_entrance)
        )
        need_add_vars = list(set(current_block_entrance) - set(exist_vars))
        need_ignore_vars = []
        for var in need_add_vars:
            if "@GRAD" not in var:
                need_ignore_vars.append(var)
        need_add_vars = list(
            set(need_add_vars).difference(set(need_ignore_vars))
        )
        previous_block_private = block_var_detail[index + 1]["backward"][
            "private"
        ]
        previous_block_entrance = block_var_detail[index + 1]["backward"][
            "entrance"
        ]
        for var in need_add_vars:
            if (
                var not in previous_block_private
                and var not in previous_block_entrance
            ):
                previous_block_entrance.append(var)
            previous_block_exit.append(var)
    return block_var_detail


def find_need_var_from_previous_block(
    need_add_vars, block_var_detail, current_index, heter_ops
):
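    """
    Filter need_add_vars: a var can be ignored when an earlier block on the
    same default device already holds it, since no cross-device communication
    is required in that case.
    """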
    # create index_device_map
    index_device_map = {}
    for index in range(len(block_var_detail)):
        index_device_map[index] = DEFAULT_DEVICE
    for device in heter_ops:
        for index in heter_ops[device].keys():
            if index < len(block_var_detail):
                index_device_map[index] = device

    need_ignore_var = []

    # if need_add_var is already present on the current device, no communication is needed
    for var in need_add_vars:
        pre_index = current_index - 1
        while pre_index >= 0:
            previous_block_private = block_var_detail[pre_index]["private"]
            previous_block_exit = block_var_detail[pre_index]["exit"]
            previous_block_entrance = block_var_detail[pre_index]["entrance"]
            total_var = (
                previous_block_private
                + previous_block_exit
                + previous_block_entrance
            )
            if var in total_var:
                if (
                    index_device_map[current_index]
                    == index_device_map[pre_index]
                    and index_device_map[current_index] == DEFAULT_DEVICE
                ):
                    need_ignore_var.append(var)
                    break
            pre_index -= 1

    need_add_vars = list(set(need_add_vars).difference(set(need_ignore_var)))
    return need_add_vars


def delete_block_useless_exit(
    program, program_block_ops_list, block_var_detail
):
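    """
    Drop exit vars that the consuming block never lists as an entrance
    (the block at index + 1 for forward stages, index - 1 for backward).
    """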
    # forward
    for index in range(len(block_var_detail)):
        if index == len(block_var_detail) - 1:
            break
        current_block_exit = block_var_detail[index]["forward"]["exit"]
        next_block_entrance = block_var_detail[index + 1]["forward"]["entrance"]
        need_delete_var = []
        for var in current_block_exit:
            if var not in next_block_entrance:
                need_delete_var.append(var)

        for var in need_delete_var:
            current_block_exit.remove(var)
    # backward
    for index in range(len(block_var_detail) - 1, -1, -1):
        if index - 1 < 0:
            break
        current_block_exit = block_var_detail[index]["backward"]["exit"]
        next_block_entrance = block_var_detail[index - 1]["backward"][
            "entrance"
        ]
        need_delete_var = []
        for var in current_block_exit:
            if var not in next_block_entrance:
                need_delete_var.append(var)
        for var in need_delete_var:
            current_block_exit.remove(var)

    return block_var_detail


def check_op_device(block, device):
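    """Set the op_device attribute of every op in block to the given device."""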
    for op in block.ops:
        op._set_attr('op_device', device)


def screen_persistables(program, var_list):
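    """
    Remove persistable vars from var_list in place (for *@GRAD names the
    origin var's persistability is checked) and return the removed names.
    """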
    need_remove = []
    for var_name in var_list:
        if "@GRAD" in var_name:
            if "GRAD" != var_name.split("@")[-1]:
                continue
            origin_var_name = var_name.split("@GRAD")[0]
            var = program.global_block().vars[origin_var_name]
        else:
            var = program.global_block().vars[var_name]

        if paddle.static.is_persistable(var):
            need_remove.append(var_name)

    for var_name in need_remove:
        var_list.remove(var_name)
    return need_remove


def insert_reshape_op(
    program, block, index, var_name, new_var_name, new_var_shape=None
):
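    """
    Insert a reshape2 op at index that reshapes var_name into new_var_name,
    creating the output var and its XShape companion when necessary.
    """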
    input_var = block.vars[var_name]

    if new_var_name not in block.vars:
        out = block.create_var(
            name=new_var_name,
            shape=new_var_shape,
            dtype=input_var.dtype,
            type=input_var.type,
        )
    else:
        out = block.vars[new_var_name]
        new_var_shape = out.shape

    x_shape = block.create_var(
        name=f"{var_name}.xshape@Heter", dtype=input_var.dtype
    )
    block._insert_op(
        index=index,
        type="reshape2",
        inputs={"X": input_var},
        attrs={'shape': new_var_shape},
        outputs={"Out": out, "XShape": x_shape},
    )


def insert_send_concat_op(
    program, block, index, var_name_list, new_var_name, new_var_shape
):
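    """
    Insert a concat op at index that joins var_name_list along the last axis
    into a single joint var new_var_name created in the global block.
    """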
    input_var_list = [block.vars[var_name] for var_name in var_name_list]

    out = program.global_block().create_var(
        name=new_var_name,
        shape=new_var_shape,
        dtype=input_var_list[0].dtype,
        type=input_var_list[0].type,
    )

    block._insert_op(
        index=index,
        type='concat',
        inputs={"X": input_var_list},
        outputs={'Out': [out]},
        attrs={'axis': -1, 'use_stack': False},
    )


def insert_recv_slice_op(
    program,
    block,
    index,
    var_name,
    var_shape,
    dtype,
    type,
    new_var_name_list,
    new_var_shape_list,
):
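    """
    Insert slice ops at index that cut the received joint var var_name along
    axis 1 into the per-var segments named in new_var_name_list, sized by
    new_var_shape_list.
    """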
    if var_name not in program.global_block().vars:
        input_var = program.global_block().create_var(
            name=var_name, shape=var_shape, dtype=dtype, type=type
        )
    else:
        input_var = program.global_block().vars[var_name]

    out_list = []
    for i in range(len(new_var_name_list)):
        if new_var_name_list[i] not in block.vars:
            out = block.create_var(
                name=new_var_name_list[i],
                shape=new_var_shape_list[i],
                dtype=input_var.dtype,
                type=input_var.type,
            )
        else:
            out = block.vars[new_var_name_list[i]]
        out_list.append(out)

    start_index = 0
    end_index = 0
    for i in range(len(new_var_name_list)):
        starts = []
        ends = []
        attrs = {'axes': [1]}
        end_index += new_var_shape_list[i][1]
        starts.append(start_index)
        ends.append(end_index)
        attrs['starts'] = starts
        attrs['ends'] = ends

        block._insert_op(
            index=index,
            type='slice',
            inputs={'Input': input_var},
            attrs=attrs,
            outputs={'Out': out_list[i]},
        )
        start_index = end_index
        index += 1


def add_heter_trainer_useful_vars(
    config, program, heter_program, heter_block, static_var
):
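    """
    Clone the vars named in static_var from program into the heter side:
    persistable vars go to heter_program's global block, the rest to
    heter_block.
    """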
    static_var = list(set(static_var))
    for var_name in static_var:
        if (
            var_name not in heter_program.global_block().vars
            and var_name not in heter_block.vars
        ):
            var = program.global_block().vars[var_name]
            if var.persistable:
                heter_program.global_block()._clone_variable(
                    var, force_persistable=False
                )
            else:
                heter_block._clone_variable(var, force_persistable=False)


def delete_trainer_useless_var(config, program, static_var):
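    """
    Remove every var in the trainer program's global block that is neither
    referenced by any op nor listed in static_var; return the removed names.
    """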
    static_var = list(set(static_var))
    program_useful_var_list = []
    for op in program.global_block().ops:
        input_var_list, output_var_list = find_op_input_output(
            program, program.global_block(), op
        )
        op_var_list = list(set(input_var_list).union(set(output_var_list)))
        program_useful_var_list = list(
            set(program_useful_var_list).union(set(op_var_list))
        )
    program_useful_var_list += static_var
    program_useless_var_list = list(
        set(get_vars_name_in_block(program.global_block())).difference(
            set(program_useful_var_list)
        )
    )
    for var in program_useless_var_list:
        program.global_block()._remove_var(var)
    return program_useless_var_list


def block_append_op(program, origin_program, block, op):
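    """
    Append op to block, cloning any missing input/output vars from
    origin_program first. Forward ops are re-created via block.append_op;
    grad ops are copied at the desc level so their backward op-role and
    op_device attrs are preserved.
    """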
    merge_ordereddict = origin_program.global_block().vars.copy()
    merge_ordereddict.update(block.vars)
    inputs = _get_input_map_from_op(merge_ordereddict, op)
    for key, varlist in inputs.items():
        if not isinstance(varlist, list):
            varlist = [varlist]
        for var in varlist:
            if (
                var.name not in program.global_block().vars
                and var.name not in block.vars
            ):
                if var.persistable:
                    program.global_block()._clone_variable(
                        var, force_persistable=False
                    )
                else:
                    block._clone_variable(var, force_persistable=False)

    outputs = _get_output_map_from_op(origin_program.global_block().vars, op)
    for key, varlist in outputs.items():
        if not isinstance(varlist, list):
            varlist = [varlist]
        for var in varlist:
            if (
                var.name not in program.global_block().vars
                and var.name not in block.vars
            ):
                if var.persistable:
                    program.global_block()._clone_variable(
                        var, force_persistable=False
                    )
                else:
                    block._clone_variable(var, force_persistable=False)

    if "_grad" not in op.type:
        # for forward op
        return block.append_op(
            type=op.type, inputs=inputs, outputs=outputs, attrs=op.all_attrs()
        )
    else:
        # for grad op
        op_desc = op.desc
        op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
        backward = core.op_proto_and_checker_maker.OpRole.Backward
        device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()

        # append grad op
        new_op_desc = block.desc.append_op()
        new_op_desc.copy_from(op_desc)
        new_op_desc._set_attr(op_role_attr_name, backward)

        # set device grad
        if op.desc.has_attr(device_attr_name):
            op_device = op_desc.attr(device_attr_name)
            new_op_desc._set_attr(device_attr_name, op_device)
        block._sync_with_cpp()


def add_vars_by_var_list(var_name_list, origin_program, program, block):
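    """
    Clone the vars in var_name_list from origin_program: persistable vars go
    to program's global block, the rest to block.
    """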
    for var_name in var_name_list:
        if (
            var_name not in program.global_block().vars
            and var_name not in block.vars
        ):
            var = origin_program.global_block().vars[var_name]
            if var.persistable:
                program.global_block()._clone_variable(
                    var, force_persistable=False
                )
            else:
                block._clone_variable(var, force_persistable=False)


def get_varlist_from_op_map(var_map):
    var_list = []
    for key, varlist in var_map.items():
        if not isinstance(varlist, list):
            varlist = [varlist]
        for i in range(len(varlist)):
            var = varlist[i]
            var_list.append(var.name)
    return var_list


def find_ops_list_input_output(program, ops_list):
    input_var_list = []
    output_var_list = []
    for op in ops_list:
        inputs = _get_input_map_from_op(program.global_block().vars, op)
        input_var_list += get_varlist_from_op_map(inputs)
        outputs = _get_output_map_from_op(program.global_block().vars, op)
        output_var_list += get_varlist_from_op_map(outputs)

    input_var_list = list(set(input_var_list))
    output_var_list = list(set(output_var_list))
    return input_var_list, output_var_list


def find_op_input_output(program, block, op):
    input_var_list = []
    output_var_list = []
    inputs = _get_input_map_from_op(block.vars, op)
    input_var_list += get_varlist_from_op_map(inputs)
    outputs = _get_output_map_from_op(block.vars, op)
    output_var_list += get_varlist_from_op_map(outputs)
    input_var_list = list(set(input_var_list))
    output_var_list = list(set(output_var_list))
    return input_var_list, output_var_list


def get_vars_name_in_block(block):
    vars_list = block.vars.keys()
    vars_name_list = list(vars_list)
    return vars_name_list


def is_same_op(op1, op2):
    return str(op1) == str(op2)


def _get_input_map_from_op(varmap, op):
    """Returns a dict from op input name to the vars in varmap."""
    iomap = collections.OrderedDict()
    for key in op.input_names:
        vars = []
        for varname in op.input(key):
            if varname == "@EMPTY@":
                continue
            if "lod_tensor_blocking_queue" in varname:
                continue
            vars.append(varmap[varname])
        if len(vars) == 1:
            iomap[key] = vars[0]
        else:
            iomap[key] = vars
    return iomap


def _get_output_map_from_op(varmap, op):
    """Returns a dict from op output name to the vars in varmap."""
    iomap = collections.OrderedDict()
    for key in op.output_names:
        vars = []
        for varname in op.output(key):
            if varname == "@EMPTY@":
                continue
            if "lod_tensor_blocking_queue" in varname:
                continue
            vars.append(varmap[varname])
        if len(vars) == 1:
            iomap[key] = vars[0]
        else:
            iomap[key] = vars
    return iomap


def delete_same_ops(block, ops):
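    """Remove from block the first op that compares equal (by string form) to each op in ops."""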
    for op in ops:
        try:
            for origin_op in block.ops:
                if is_same_op(origin_op, op):
                    idx = list(block.ops).index(origin_op)
                    block._remove_op(idx)
                    break
        except Exception as e:
            warnings.warn(str(e))
