# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import paddle
from paddle.framework import core
from paddle.utils import unique_name

from ..base.private_helper_function import wait_server_ready

# Nothing from this module is exported via ``from ... import *``.
__all__ = []

# Operator-role enum from the C++ op proto maker; its members are used as
# integer bit masks by the ``is_*_op`` predicates in this module.
OpRole = core.op_proto_and_checker_maker.OpRole

# Name of the op attribute that stores an op's role (checked against
# ``op.attr_names`` below).
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
# Name of the op attribute that presumably stores role-related variable
# names (not used in this module) -- NOTE(review): confirm semantics.
OP_ROLE_VAR_KEY = core.op_proto_and_checker_maker.kOpRoleVarAttrName()


def is_update_op(op):
    """Return True if ``op`` takes the inputs of a parameter-update op.

    An op qualifies when its input names include all of ``Param``,
    ``Grad`` and ``LearningRate``.
    """
    required = ('Param', 'Grad', "LearningRate")
    names = op.input_names
    return all(slot in names for slot in required)


def is_loss_grad_op(op):
    """Tell whether ``op`` is tagged with both the Backward and Loss roles.

    Returns ``False`` when the op has no role attribute. Otherwise returns
    an integer that is nonzero exactly when the op's role intersects both
    the ``OpRole.Backward`` and ``OpRole.Loss`` masks.
    """
    has_role = OP_ROLE_KEY in op.attr_names
    if has_role:
        role_bits = int(op.all_attrs()[OP_ROLE_KEY])
        backward_bits = role_bits & int(OpRole.Backward)
        # ``and`` keeps the original value semantics: 0 when not backward,
        # otherwise the masked Loss bits.
        return backward_bits and role_bits & int(OpRole.Loss)
    return False


def is_backward_op(op):
    """Tell whether ``op`` is tagged with the Backward role.

    Returns ``False`` when the op has no role attribute, otherwise the
    integer ``role & OpRole.Backward`` (nonzero means backward).
    """
    if OP_ROLE_KEY not in op.attr_names:
        return False
    role_bits = int(op.all_attrs()[OP_ROLE_KEY])
    return role_bits & int(OpRole.Backward)


def is_optimizer_op(op):
    """Tell whether ``op`` is tagged with the Optimize role.

    Returns ``False`` when the op has no role attribute, otherwise the
    integer ``role & OpRole.Optimize`` (nonzero means optimizer op).
    """
    if OP_ROLE_KEY not in op.attr_names:
        return False
    role_bits = int(op.all_attrs()[OP_ROLE_KEY])
    return role_bits & int(OpRole.Optimize)


class CollectiveHelper:
    """Append collective-communication setup ops to a static startup program.

    For each of ``nrings`` rings the helper appends an op that generates a
    communicator id (``c_gen_nccl_id`` on CUDA builds, ``c_gen_bkcl_id`` on
    XPU builds) followed by ``c_comm_init``. It then broadcasts every
    non-distributed parameter from root 0.

    Args:
        role_maker: Object providing ``_get_trainer_endpoints()`` and
            ``_worker_index()`` for the current job.
        nrings: Number of communication rings to initialize.
        wait_port: Forwarded to ``_init_communicator``; ``_wait`` asserts
            that it is truthy.
    """

    def __init__(self, role_maker, nrings=1, wait_port=True):
        self.nrings = nrings
        self.wait_port = wait_port
        self.role_maker = role_maker

    def update_startup_program(self, startup_program=None):
        """Append communicator-init and parameter-broadcast ops.

        Args:
            startup_program: Program to modify. Defaults to
                ``paddle.static.default_startup_program()``.
        """
        self.startup_program = startup_program
        if startup_program is None:
            self.startup_program = paddle.static.default_startup_program()

        endpoints = self.role_maker._get_trainer_endpoints()
        rank = self.role_maker._worker_index()
        current_endpoint = endpoints[rank]
        for ring_id in range(self.nrings):
            self._init_communicator(
                self.startup_program,
                current_endpoint,
                endpoints,
                rank,
                ring_id,
                self.wait_port,
            )
        self._broadcast_params()

    def _init_communicator(
        self,
        program,
        current_endpoint,
        endpoints,
        rank,
        ring_id,
        wait_port,
        global_ring_id=None,
        sync=True,
    ):
        """Append ops to ``program`` that create communicator ``ring_id``.

        When ``current_endpoint`` is None, no group is created. Only the
        synchronizing all_reduce over ``global_ring_id`` is appended, and
        this mode requires ``endpoints is None`` and a true ``sync``.

        Args:
            program: Program whose global block receives the new ops.
            current_endpoint: This worker's endpoint, or None for sync-only
                mode.
            endpoints: All endpoints of the group; must contain
                ``current_endpoint`` unless in sync-only mode.
            rank: This worker's rank within the group.
            ring_id: Id of the ring to create.
            wait_port: Not used by this method. NOTE(review): confirm
                whether a ``wait_server_ready`` call was intended here.
            global_ring_id: ``ring_id`` attribute of the synchronizing
                all_reduce.
            sync: Whether to append the synchronizing all_reduce after
                communicator init.

        Raises:
            ValueError: If ``current_endpoint`` is not in ``endpoints``, or
                if Paddle was built with neither CUDA nor XPU.
        """

        def _add_sync_by_allreduce(block):
            # All ranks all_reduce a constant 1 on ``global_ring_id`` and
            # then wait on the calc stream, which acts as a sync point.
            sync_var = block.create_var(
                name=unique_name.generate('sync_var'),
                dtype=core.VarDesc.VarType.INT32,
                persistable=False,
                stop_gradient=True,
            )
            block.append_op(
                type='fill_constant',
                inputs={},
                outputs={'Out': [sync_var]},
                attrs={
                    'shape': [1],
                    'dtype': sync_var.dtype,
                    'value': 1,
                    'force_cpu': False,
                    OP_ROLE_KEY: OpRole.Forward,
                },
            )
            block.append_op(
                type='all_reduce',
                inputs={'x': [sync_var]},
                outputs={'out': [sync_var]},
                attrs={
                    'ring_id': global_ring_id,
                    'reduce_type': paddle.distributed.ReduceOp.SUM,
                    OP_ROLE_KEY: OpRole.Forward,
                },
            )
            block.append_op(
                type='c_sync_calc_stream',
                inputs={'X': sync_var},
                outputs={'Out': sync_var},
                attrs={OP_ROLE_KEY: OpRole.Forward},
            )

        block = program.global_block()

        # Sync-only mode must be handled before ``endpoints`` is used: in
        # this mode ``endpoints`` is None and ``",".join(None)`` would raise
        # TypeError.
        if current_endpoint is None:
            assert endpoints is None
            assert sync
            _add_sync_by_allreduce(block)
            return

        endpoints_str = ",".join(endpoints)
        nranks = len(endpoints)
        other_endpoints = endpoints[:]
        other_endpoints.remove(current_endpoint)

        # Only the id-generation op differs per backend; c_comm_init is
        # shared.
        if core.is_compiled_with_cuda():
            gen_id_op_type = 'c_gen_nccl_id'
        elif core.is_compiled_with_xpu():
            gen_id_op_type = 'c_gen_bkcl_id'
        else:
            raise ValueError(
                "comm_id must be generated in paddlepaddle-gpu or paddlepaddle-xpu."
            )

        comm_id_var = block.create_var(
            name=unique_name.generate('comm_id'),
            persistable=True,
            type=core.VarDesc.VarType.RAW,
        )
        block.append_op(
            type=gen_id_op_type,
            inputs={},
            outputs={'Out': comm_id_var},
            attrs={
                'rank': rank,
                'endpoint': current_endpoint,
                'other_endpoints': other_endpoints,
                'ring_id': ring_id,
                OP_ROLE_KEY: OpRole.Forward,
            },
        )
        block.append_op(
            type='c_comm_init',
            inputs={'X': comm_id_var},
            outputs={},
            attrs={
                'nranks': nranks,
                'rank': rank,
                'ring_id': ring_id,
                'endpoints': endpoints_str,
                OP_ROLE_KEY: OpRole.Forward,
            },
        )
        if sync:
            _add_sync_by_allreduce(block)

    def _wait(self, current_endpoint, endpoints):
        """Block until every endpoint other than ``current_endpoint`` is ready.

        Requires ``self.wait_port`` to be truthy (asserted).
        """
        assert self.wait_port
        other_endpoints = endpoints[:]
        other_endpoints.remove(current_endpoint)
        wait_server_ready(other_endpoints)

    def _broadcast_params(self):
        """Broadcast non-distributed startup parameters from root 0.

        Broadcasts are assigned round-robin over the ``nrings`` rings, and
        one ``c_sync_comm_stream`` op per ring is appended afterwards. If
        the startup block has no parameters, nothing is appended.
        """
        block = self.startup_program.global_block()
        ring_id = -1
        # Last parameter seen (distributed or not); the sync ops below are
        # anchored on it.
        last_param = None
        for param in block.iter_parameters():
            last_param = param
            if param.is_distributed:
                continue

            ring_id = (ring_id + 1) % self.nrings
            block.append_op(
                type='broadcast',
                inputs={'x': param},
                outputs={'out': param},
                attrs={
                    'ring_id': ring_id,
                    'root': 0,
                    OP_ROLE_KEY: OpRole.Forward,
                },
            )

        # With no parameters there is no tensor to attach the sync ops to
        # and nothing was broadcast.
        if last_param is None:
            return

        for ring_id in range(self.nrings):
            block.append_op(
                type='c_sync_comm_stream',
                inputs={'X': last_param},
                outputs={'Out': last_param},
                attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Forward},
            )
