# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings

import numpy as np

import paddle

from ...utils.log_utils import get_logger


class Converter:
    """
    Converter is a class object for auto parallel to convert tensors from
    one parallel strategy to another one. Tensors will merge and slice value
    with their strategy when strategies are different.
    """

    def __init__(self, tensors_dict, pre_strategy, cur_strategy):
        """
        Args:
            tensors_dict(dict): tensors' value of all ranks that to be converted.
                key is tensor's name(str), value is all ranks' data(list(numpy.ndarray))
            pre_strategy(dict): tensors' distributed attribute of last training process.
                key is tensor's name(str), value is tensor's distributed attribute in last
                training process.
            cur_strategy(dict): tensors' distributed attribute of current rank.
                key is tensor's name(str), value is tensor's distributed attribute in current
                rank.
        """
        self._tensors_dict = self._check_tensor_dict(tensors_dict)
        self._pre_strategy = self._check_pre_strategy(pre_strategy)
        self._cur_strategy = self._check_cur_strategy(cur_strategy)
        self._logger = get_logger(logging.INFO)

    def _check_tensor_dict(self, tensors_dict):
        if not tensors_dict:
            raise ValueError(
                "'tensors_dict' is None or empty, "
                "the tensors to be converted cannot be empty."
            )
        if not isinstance(tensors_dict, dict):
            raise TypeError(
                f"The type of 'tensors_dict' should be 'dict', but got '{type(tensors_dict)}'."
            )
        return tensors_dict

    def _check_pre_strategy(self, pre_strategy):
        if not pre_strategy:
            raise ValueError(
                "'pre_strategy' is None or empty, there are no tensors "
                "in the previous process."
            )
        if not isinstance(pre_strategy, dict):
            raise TypeError(
                "The type of 'pre_strategy' should be 'dict', "
                f"but got '{type(pre_strategy)}'."
            )
        return pre_strategy

    def _check_cur_strategy(self, cur_strategy):
        if not cur_strategy:
            warnings.warn(
                "'cur_strategy' is None or empty, there are no tensors "
                "in the current process."
            )
        if not isinstance(cur_strategy, dict):
            raise TypeError(
                "The type of 'cur_strategy' should be 'dict', "
                f"but got '{type(cur_strategy)}'."
            )
        return cur_strategy

    def convert(self, strict=True):
        """
        Convert tensors

        Args:
            strict(bool): whether to strict convert tensor with tensor's name. If False, it will
            convert tensors by prefix matching. Otherwise, tensors will be converted with
            their name strictly.

        Returns:
            converted tensors(dict)

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> import numpy as np
                >>> from paddle.distributed.auto_parallel.static.converter import Converter
                >>> complete_tensors = np.arange(4).reshape([2, 2])
                >>> partial_tensors = np.split(complete_tensors, 2, axis=0)
                >>> name = "tmp_0"
                >>> tensors_dict = {name: partial_tensors}
                >>> strategy_1 = {
                ...     name: {
                ...         "process_shape": [2],
                ...         "process_group": [0, 1],
                ...         "dims_mapping": [0, -1]
                ...     }
                ... }
                >>> strategy_2 = {
                ...     name: {
                ...         "process_shape": [2],
                ...         "process_group": [0, 1],
                ...         "dims_mapping": [-1, -1]
                ...     }
                ... }
                >>> converter = Converter(tensors_dict, strategy_1, strategy_2)
                >>> result = converter.convert()
                >>> # the result's value is equal to `complete_tensors`
        """
        tensors_dict = {}
        # names in cur_strategy but not in pre_strategy
        tensor_not_in_pre = []
        # names in pre_strategy but not in cur_strategy
        tensor_not_in_cur = []
        # names in the strategies but not in the checkpoint files
        tensor_not_in_ckpt = []
        self._logger.info("Start to convert tensors.")
        for tensor_name in self._cur_strategy:
            if tensor_name not in self._pre_strategy:
                tensor_not_in_pre.append(tensor_name)
                continue
            if tensor_name not in self._tensors_dict:
                tensor_not_in_ckpt.append(tensor_name)
                continue
            self._pre_name = tensor_name
            self._cur_name = tensor_name
            tensor_list = self._tensors_dict[tensor_name]
            pre_dist_attr = self._pre_strategy[tensor_name]
            cur_dist_attr = self._cur_strategy[tensor_name]
            try:
                tensors_dict[tensor_name] = Converter.merge_and_slice(
                    tensor_list, pre_dist_attr, cur_dist_attr
                )
            except ValueError as err:
                raise ValueError(
                    f"Failed to convert tensor '{tensor_name}'. {err}"
                ) from err

        for tensor_name in self._pre_strategy:
            if tensor_name not in self._cur_strategy:
                tensor_not_in_cur.append(tensor_name)

        if not strict:
            (
                tensors_dict,
                tensor_match_with_pre,
                tensor_match_with_cur,
            ) = self.convert_with_prefix_match(
                tensors_dict, tensor_not_in_pre, tensor_not_in_cur
            )
        else:
            tensor_match_with_pre = []
            tensor_match_with_cur = []

        tensor_not_in_pre = set(tensor_not_in_pre) - set(tensor_match_with_pre)
        tensor_not_in_cur = set(tensor_not_in_cur) - set(tensor_match_with_cur)
        if tensor_not_in_pre:
            warnings.warn(
                f"tensors [{tensor_not_in_pre}] are not found in the last training strategy."
            )
        if tensor_not_in_cur:
            warnings.warn(
                f"tensors [{tensor_not_in_cur}] are not found in the current training strategy."
            )
        if tensor_not_in_ckpt:
            warnings.warn(
                f"tensors [{tensor_not_in_ckpt}] are found in pre_strategy, but are not "
                "found in the checkpoint files, please check your checkpoint files."
            )

        return tensors_dict

    def convert_with_prefix_match(
        self, tensors_dict, tensor_not_in_pre, tensor_not_in_cur
    ):
        # names in cur_process that can be matched with a name in pre_process
        tensor_match_with_pre = []
        # names in pre_process that can be matched with a name in cur_process
        tensor_match_with_cur = []
        for cur_name in tensor_not_in_pre:
            prefix_name = cur_name
            while prefix_name.find("_") != -1:
                prefix_name = prefix_name[: prefix_name.rfind("_")]
                for pre_name in tensor_not_in_cur:
                    if prefix_name in pre_name:
                        # 'cur_name' of cur_process can match with 'pre_name' of pre_process
                        self._pre_name = pre_name
                        self._cur_name = cur_name
                        pre_tensor_list = self._tensors_dict[pre_name]
                        pre_dist_attr = self._pre_strategy[pre_name]
                        cur_dist_attr = self._cur_strategy[cur_name]
                        try:
                            tensors_dict[cur_name] = Converter.merge_and_slice(
                                pre_tensor_list, pre_dist_attr, cur_dist_attr
                            )
                        except ValueError as err:
                            raise ValueError(
                                f"Failed to convert tensor '{cur_name}' "
                                f"using '{pre_name}'. {err}"
                            ) from err
                        self._logger.info(
                            f"tensor [{cur_name}] is matched with tensor [{pre_name}]"
                        )
                        tensor_match_with_pre.append(cur_name)
                        tensor_match_with_cur.append(pre_name)
                        break
                else:
                    # no match at this prefix length; trim further and retry
                    continue
                break

        return tensors_dict, tensor_match_with_pre, tensor_match_with_cur

    @staticmethod
    def merge_and_slice(tensor_list, pre_dist_attr, cur_dist_attr):
        """
        Merge tensors with previous dist_attr and slice tensors with current dist_attr

        Returns:
            tensor(numpy.narray): a tensor's value of current rank.
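
        Examples:
            A minimal sketch (mirroring the example of 'convert'): pieces
            sharded along dim 0 by the previous strategy are merged back into
            one replicated tensor; the attribute values are illustrative.

            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> import numpy as np
                >>> from paddle.distributed.auto_parallel.static.converter import Converter
                >>> complete = np.arange(4).reshape([2, 2])
                >>> tensor_list = np.split(complete, 2, axis=0)
                >>> pre_attr = {
                ...     "process_shape": [2],
                ...     "process_group": [0, 1],
                ...     "dims_mapping": [0, -1]
                ... }
                >>> cur_attr = {
                ...     "process_shape": [2],
                ...     "process_group": [0, 1],
                ...     "dims_mapping": [-1, -1]
                ... }
                >>> merged = Converter.merge_and_slice(tensor_list, pre_attr, cur_attr)
                >>> # `merged` is equal to `complete` on every rank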
        """
        assert isinstance(tensor_list, list)
        assert all(isinstance(p, np.ndarray) for p in tensor_list)

        if pre_dist_attr == cur_dist_attr:
            # same strategy on both sides: every rank already holds its shard,
            # so just pick this rank's entry from the list
            rank_id = paddle.distributed.get_rank()
            index = cur_dist_attr["process_group"].index(rank_id)
            tensor = tensor_list[index]
        else:
            pre_dims_mapping = pre_dist_attr["dims_mapping"]
            cur_dims_mapping = cur_dist_attr["dims_mapping"]

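            # a dims_mapping entry other than -1 marks a dimension that is
            # sharded across the process mesh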
            if len(pre_dims_mapping) and (
                len(set(pre_dims_mapping)) > 1 or -1 not in pre_dims_mapping
            ):
                # merge tensor
                tensor = Converter.merge_with_dist_attr(
                    tensor_list, pre_dist_attr
                )
            else:
                # skip merge tensor
                tensor = tensor_list[0]

            if len(cur_dims_mapping) and (
                len(set(cur_dims_mapping)) > 1 or -1 not in cur_dims_mapping
            ):
                # slice tensor
                tensor = Converter.slice_with_dist_attr(tensor, cur_dist_attr)

        return tensor

    @staticmethod
    def merge_with_dist_attr(tensor_list, dist_attr):
        """Merge tensor with distributed attribute"""
        from .reshard import Resharder

        dims_mapping = dist_attr["dims_mapping"]
        process_shape = dist_attr["process_shape"]
        process_group = dist_attr["process_group"]
        # get the complete shape of the tensor
        complete_shape = Resharder.compute_complete_shape(
            tensor_list[0].shape, process_shape, dims_mapping
        )
        # merge the tensor with dist_attr
        partition_tensor_list = []
        merged_partition = []
        for process in process_group:
            partition_index = Resharder.compute_partition_index(
                process,
                complete_shape,
                dims_mapping,
                process_shape,
                process_group,
            )
            index = process_group.index(process)
            if partition_index not in merged_partition:
                merged_partition.append(partition_index)
                Converter.merge(
                    partition_tensor_list,
                    tensor_list[index],
                    partition_index,
                    complete_shape,
                )

        if len(partition_tensor_list) != 1:
            raise ValueError(
                f"Fail to merge tensor with dist_attr '{dist_attr}'."
            )
        complete_tensor = partition_tensor_list[0][0]
        return complete_tensor

    @staticmethod
    def slice_with_dist_attr(tensor, dist_attr):
        """Slice tensor with distributed attribute"""
        dims_mapping = dist_attr["dims_mapping"]
        if len(dims_mapping) == 0:
            # NOTE: a scalar tensor does not need to be split
            return tensor
        process_shape = dist_attr["process_shape"]
        process_group = dist_attr["process_group"]
        # slice the tensor with dist_attr
        partition_index_list = Converter._get_split_indices(
            tensor.shape, dims_mapping, process_shape, process_group
        )
        sliced_tensor_list = Converter.split(
            tensor, partition_index_list, len(partition_index_list)
        )
        # get the current tensor's index in sliced_tensor_list
        rank_id = paddle.distributed.get_rank()
        sliced_tensor_index = Converter._get_sliced_index(
            rank_id, tensor.shape, dims_mapping, process_shape, process_group
        )
        if sliced_tensor_index not in range(len(sliced_tensor_list)):
            raise ValueError(
                f"Fail to slice tensor with dist_attr '{dist_attr}'."
            )
        sliced_tensor = sliced_tensor_list[sliced_tensor_index]
        return sliced_tensor

    @staticmethod
    def merge(partition_tensor_list, tensor, partition_index, complete_shape):
        """
        Merge partial tensors into one complete tensor.

        Returns:
            None

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> import numpy as np
                >>> import paddle
                >>> from paddle.distributed.auto_parallel.static.converter import Converter
                >>> partition_tensor_list = [(np.array([[[1.11, 1.12]]]), [[0,1],[0,1],[0,2]])]
                >>> tensor = np.array([[[1.13, 1.14]]])
                >>> partition_index = [[0,1],[0,1],[2,4]]
                >>> complete_shape = [1, 1, 4]

                >>> Converter.merge(partition_tensor_list, tensor, partition_index, complete_shape)
                >>> print(partition_tensor_list)
                [(array([[[1.11, 1.12, 1.13, 1.14]]]), [[0, 1], [0, 1], [0, 4]])]
        """
        from .reshard import Resharder

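        # if the single remaining partition already spans 'complete_shape',
        # the tensor is fully merged and there is nothing to do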
        if len(partition_tensor_list) == 1:
            is_complete_data = True
            for idx, item in enumerate(partition_tensor_list[0][1]):
                if item[0] != 0 or item[1] != complete_shape[idx]:
                    is_complete_data = False
                    break
            if is_complete_data:
                return

        if not partition_tensor_list:
            partition_tensor_list.append((tensor, partition_index))
        else:
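            # try to concatenate 'tensor' with an adjacent partition; on
            # success, restart the merge with the combined partition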
            i = 0
            while i < len(partition_tensor_list):
                (
                    concat_axis,
                    first_order,
                    new_partition,
                ) = Resharder.compute_concat_info(
                    partition_tensor_list[i][1], partition_index
                )
                if concat_axis != -1:
                    if first_order == 0:
                        new_tensor = np.concatenate(
                            (partition_tensor_list[i][0], tensor),
                            axis=concat_axis,
                        )
                    else:
                        new_tensor = np.concatenate(
                            (tensor, partition_tensor_list[i][0]),
                            axis=concat_axis,
                        )

                    partition_tensor_list.pop(i)
                    Converter.merge(
                        partition_tensor_list,
                        new_tensor,
                        new_partition,
                        complete_shape,
                    )
                    break
                i += 1

    @staticmethod
    def split(complete_tensor, partition_index_list, length):
        """
        Slice a complete tensor.

        Returns:
            sliced_tensor_list(list): the tensors sliced according to 'partition_index_list'

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> import numpy as np
                >>> from paddle.distributed.auto_parallel.static.converter import Converter
                >>> complete_tensor = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]])
                >>> partition_index_list = [[], [], [2, 4]]

                >>> sliced_tensor_list = Converter.split(complete_tensor, partition_index_list, 3)
                >>> print(sliced_tensor_list)
                [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])]
        """
        sliced_tensor_list = []
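        # each recursion level splits one axis, starting from axis 0;
        # 'length' counts the axes that are still left to split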
        axis = len(complete_tensor.shape) - length
        sliced_tensor = np.split(
            complete_tensor, partition_index_list[axis], axis=axis
        )
        if length == 1:
            return sliced_tensor
        for tensor in sliced_tensor:
            sliced_tensor_list.extend(
                Converter.split(tensor, partition_index_list, length - 1)
            )
        return sliced_tensor_list

    @staticmethod
    def _get_split_indices(
        complete_shape, dims_mapping, process_shape, process_group
    ):
        """
        Get split indices of every dimension.

        Returns:
            split_indices_list(list): the split indices of every dimension of the tensor

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> from paddle.distributed.auto_parallel.static.converter import Converter
                >>> complete_shape = [1, 1, 6]
                >>> dims_mapping = [-1, -1, 0]
                >>> process_shape = [3]
                >>> process_group = [0, 1, 2]

                >>> index = Converter._get_split_indices(complete_shape, dims_mapping, process_shape, process_group)
                >>> print(index)
                [[], [], [2, 4]]
        """
        from .reshard import Resharder

        split_indices_list = []
        for process in process_group:
            partition_index = Resharder.compute_partition_index(
                process,
                complete_shape,
                dims_mapping,
                process_shape,
                process_group,
            )
            if split_indices_list:
                for dim in range(len(partition_index)):
                    split_indices_list[dim].extend(partition_index[dim])
            else:
                split_indices_list = partition_index
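        # drop 0 and the full dimension size so that only interior split
        # points remain in each dimension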
        split_indices_list = list(
            map(
                lambda x, y: list(set(x) - {y} - {0}),
                split_indices_list,
                complete_shape,
            )
        )
        split_indices_list = [sorted(x) for x in split_indices_list]
        return split_indices_list

    @staticmethod
    def _get_sliced_index(
        rank_id, complete_shape, dims_mapping, process_shape, process_group
    ):
        """
        Get the index of the current rank's sliced tensor in the list of all
        sliced tensors.

        Returns:
            sliced_tensor_index(int): the index of sliced tensor in sliced_tensor_list

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:DISTRIBUTED)
                >>> from paddle.distributed.auto_parallel.static.converter import Converter
                >>> rank = 2
                >>> complete_shape = [1, 1, 6]
                >>> dims_mapping = [-1, -1, 0]
                >>> process_shape = [3]
                >>> process_group = [0, 1, 2]

                >>> index = Converter._get_sliced_index(rank, complete_shape, dims_mapping,
                ...                                 process_shape, process_group)
                >>> print(index)
                2
        """
        from .reshard import Resharder

        partition_index = Resharder.compute_partition_index(
            rank_id, complete_shape, dims_mapping, process_shape, process_group
        )
        sliced_index = 0
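        # fold the per-dimension shard positions into a single row-major index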
        for i, shape in enumerate(complete_shape):
            if dims_mapping[i] == -1:
                slice_shape = shape
            else:
                slice_shape = shape // process_shape[dims_mapping[i]]
            if slice_shape == 1:
                index = partition_index[i][0]
            else:
                index = (partition_index[i][0] + 1) // slice_shape
            sliced_index = sliced_index * (shape // slice_shape) + index
        return sliced_index
