#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition of Role Makers."""

from __future__ import annotations

import os
import re
import time
import warnings
from multiprocessing import Manager, Process
from typing import TYPE_CHECKING, Any, ClassVar, Literal

import numpy as np

import paddle
from paddle.base import core
from paddle.distributed.fleet.base.private_helper_function import (
    wait_server_ready,
)

from ...backup_env import getenv_or_backup

if TYPE_CHECKING:
    import numpy.typing as npt

__all__ = []


class Role:
    WORKER: ClassVar[Literal[1]] = 1
    SERVER: ClassVar[Literal[2]] = 2
    HETER_WORKER: ClassVar[Literal[3]] = 3
    ALL: ClassVar[Literal[4]] = 4
    COORDINATOR: ClassVar[Literal[5]] = 5


class Gloo:
    """
    Gloo is a universal class for barrier and collective communication
    """

    class RENDEZVOUS:
        HDFS = 1
        FILE = 2
        HTTP = 3

    def __init__(self):
        self._worker_comm = None
        self._server_comm = None
        self._nodes_comm = None

        self._comm_world = ["worker", "server", "all"]
        self._err_init = (
            "gloo is not initialized, will not communicator with other nodes"
        )
        self._err_type = "gloo initialized error, please check arguments"
        self._err_world = (
            f"argument error, comm_world must in {self._comm_world}"
        )

        self._is_initialized = False
        self._init_timeout_seconds = 3600
        self._run_timeout_seconds = 9999999

        self._rendezvous = None
        self._role = None
        self._iface = None

        self._role_id = -1
        self._worker_num = -1
        self._server_num = -1
        self._need_init_all = False

    def init(
        self,
        rendezvous,
        role,
        role_id,
        worker_num,
        server_num,
        need_init_all=False,
        kwargs=None,
    ):
        self._rendezvous = rendezvous
        self._role = role
        self._role_id = role_id
        self._worker_num = worker_num
        self._server_num = server_num
        self._need_init_all = need_init_all
        self._iface = ""
        self._prefix = kwargs.get("store.prefix", "")

        http_server = None
        if self._rendezvous == Gloo.RENDEZVOUS.HDFS:
            dfs_name = kwargs.get("dfs.name", "")
            dfs_ugi = kwargs.get("dfs.ugi", "")
            dfs_path = kwargs.get("dfs.path", "")

            if not dfs_name or not dfs_ugi or not dfs_path:
                raise ValueError(self._err_type)
            self._init_dfs(dfs_name, dfs_ugi, dfs_path, self._prefix)

        elif self._rendezvous == Gloo.RENDEZVOUS.FILE:
            fs_path = kwargs.get("dfs.path", "")

            if not fs_path:
                raise ValueError(self._err_type)
            self._init_fs(fs_path, self._prefix)

        elif self._rendezvous == Gloo.RENDEZVOUS.HTTP:
            ip = kwargs.get("http.host", "")
            port = kwargs.get("http.port", "")
            start_http_server = kwargs.get("start_http_server", False)
            http_server_d = kwargs.get("http_server_d")

            if not ip or not port:
                raise ValueError(self._err_type)
            http_server = self._init_http(
                ip, port, self._prefix, start_http_server, http_server_d
            )
        else:
            raise ValueError(self._err_type)

        self._is_initialized = True
        self._http_server = http_server

    def _init_fs(self, fs_path, prefix):
        def init(rank, nodes, role):
            gloo = core.Gloo()
            gloo.set_rank(rank)
            gloo.set_size(nodes)
            gloo.set_prefix(prefix)
            gloo.set_iface(self._iface)
            gloo.set_timeout_seconds(
                self._init_timeout_seconds, self._run_timeout_seconds
            )
            gloo.set_hdfs_store(os.path.join(fs_path, role), "", "")
            gloo.init()
            return gloo

        if self._role == Role.WORKER:
            rank, nodes = self._get_rank_nodes(Role.WORKER)
            gloo = init(rank, nodes, "WORKER")
            self._worker_comm = gloo
        else:
            rank, nodes = self._get_rank_nodes(Role.SERVER)
            gloo = init(rank, nodes, "SERVER")
            self._server_comm = gloo

        if self._need_init_all:
            rank, nodes = self._get_rank_nodes(Role.ALL)
            gloo = init(rank, nodes, "ALL")
            self._nodes_comm = gloo

    def _init_dfs(self, dfs_name, dfs_ugi, dfs_path, prefix):
        def init(rank, nodes, role):
            gloo = core.Gloo()
            gloo.set_rank(rank)
            gloo.set_size(nodes)
            gloo.set_prefix(prefix)
            gloo.set_iface(self._iface)
            gloo.set_timeout_seconds(
                self._init_timeout_seconds, self._run_timeout_seconds
            )
            gloo.set_hdfs_store(os.path.join(dfs_path, role), dfs_name, dfs_ugi)
            gloo.init()
            return gloo

        if self._role == Role.WORKER:
            rank, nodes = self._get_rank_nodes(Role.WORKER)
            gloo = init(rank, nodes, "WORKER")
            self._worker_comm = gloo
        else:
            rank, nodes = self._get_rank_nodes(Role.SERVER)
            gloo = init(rank, nodes, "SERVER")
            self._server_comm = gloo

        if self._need_init_all:
            rank, nodes = self._get_rank_nodes(Role.ALL)
            gloo = init(rank, nodes, "ALL")
            self._nodes_comm = gloo

    def _init_http(self, ip, port, prefix, start_http_server, http_server_d):
        def __start_kv_server(http_server_d, size_d):
            print(f"start http_server: {port}, {size_d}")
            from paddle.distributed.fleet.utils.http_server import KVServer

            http_server = KVServer(port, size_d)
            http_server.start()
            wait_seconds = 5
            while (
                http_server_d.get("running", False)
                or not http_server.should_stop()
            ):
                time.sleep(wait_seconds)
            http_server.stop()

        def init_kv_server(http_server_d):
            worker_key = prefix + '_' + 'worker'
            size_d = {
                worker_key: self._worker_num,
            }
            print(f"worker_key:{worker_key}, size: {size_d}")

            http_server_d["running"] = True
            # child process for http server
            _http_server = Process(
                target=__start_kv_server, args=(http_server_d, size_d)
            )
            _http_server.daemon = True
            # set running status to True
            # start child process
            _http_server.start()
            return _http_server

        def init(rank, nodes, role):
            gloo = core.Gloo()
            gloo.set_rank(rank)
            gloo.set_size(nodes)
            gloo.set_prefix(prefix)
            gloo.set_iface(self._iface)
            gloo.set_timeout_seconds(
                self._init_timeout_seconds, self._run_timeout_seconds
            )
            gloo.set_http_store(ip, port, 'worker')
            ep = ":".join([ip, str(port)])
            wait_server_ready([ep])
            gloo.init()
            return gloo

        port = int(port)

        if start_http_server:
            print("to start http_server")
            http_server = init_kv_server(http_server_d)

        if self._role == Role.WORKER:
            rank, nodes = self._get_rank_nodes(Role.WORKER)
            gloo = init(rank, nodes, "WORKER")
            self._worker_comm = gloo
        # TODO (sandyhouse): initialize gloo for server and all

        # the closing of kv server may cause gloo init failure
        # since it depend on the full mesh connection
        # e.g. 0 connected with 1,2,3 while 2-3 not connected yet
        # TODO(kuizhiqing)
        if start_http_server:
            http_server_d["running"] = False
            http_server.join()

    def _get_rank_nodes(self, role):
        nodes = 0
        rank = -1

        if role == Role.WORKER:
            nodes = self._worker_num
            rank = self._role_id
        elif role == Role.SERVER:
            nodes = self._server_num
            rank = self._role_id
        elif role == Role.ALL:
            nodes = self._worker_num + self._server_num

            if self._role == Role.WORKER:
                rank = self._role_id
            else:
                rank = self._worker_num + self._role_id
        else:
            ValueError(self._err_type)

        return rank, nodes

    def __get_default_iface(self):
        """
        get default physical interface
        """
        default1 = self.__get_default_iface_from_gateway()
        default2 = self.__get_default_iface_from_interfaces()
        return default2 if default1 == "lo" else default1

    def __get_default_iface_from_gateway(self):
        """
        get default physical interface
        """
        res = os.popen("route -A inet").read().strip().split("\n")

        gateway_idx = None
        iface_idx = None
        for item in res:
            item = item.split()
            if "Gateway" in item and "Iface" in item:
                gateway_idx = item.index("Gateway")
                iface_idx = item.index("Iface")
            elif gateway_idx is not None and iface_idx is not None:
                gateway = None
                if len(item) > gateway_idx:
                    gateway = item[gateway_idx]
                if (
                    gateway
                    and gateway != '*'
                    and gateway != "0.0.0.0"
                    and len(item) > iface_idx
                ):
                    return item[iface_idx]
        return "lo"

    def __get_default_iface_from_interfaces(self):
        """
        get default physical interface
        """
        res = (
            os.popen("ip -f inet addr | awk NR%3==1").read().strip().split("\n")
        )
        for item in res:
            if "BROADCAST" in item:
                return item.split(":")[1].strip()
        return "lo"

    def barrier(self, comm_world):
        """
        dummy barrier, do nothing
        """
        if not self._is_initialized:
            warnings.warn(self._err_init)
            return

        if comm_world not in self._comm_world:
            raise ValueError(self._err_world)

        if comm_world == "worker":
            self._worker_comm.barrier()
        elif comm_world == "server":
            self._server_comm.barrier()
        else:
            self._nodes_comm.barrier()

    def all_reduce(self, input, mode="sum", comm_world="worker"):
        if not self._is_initialized:
            warnings.warn(self._err_init)
            return input

        if comm_world not in self._comm_world:
            raise ValueError(self._err_world)

        input = np.array(input)
        input_shape = input.shape
        input_list = input.reshape(-1).tolist()

        self.barrier(comm_world)

        if comm_world == "worker":
            ans = self._worker_comm.all_reduce(input_list, mode)
        elif comm_world == "server":
            ans = self._server_comm.all_reduce(input_list, mode)
        else:
            ans = self._nodes_comm.all_reduce(input_list, mode)

        output = np.array(ans).reshape(input_shape)
        return output

    def all_gather(self, input, comm_world="worker"):
        """
        dummy all gather, do nothing
        Args:
            obj(any): obj to do all gather
        """
        if not self._is_initialized:
            warnings.warn(self._err_init)
            return input

        if comm_world not in self._comm_world:
            raise ValueError(self._err_world)

        if comm_world == "worker":
            output = self._worker_comm.all_gather(input)
        elif comm_world == "server":
            output = self._server_comm.all_gather(input)
        else:
            output = self._nodes_comm.all_gather(input)

        return output


class RoleMakerBase:
    """
    RoleMakerBase is a base class for assigning a role to current process
    in distributed training.
    A paddle developer can implement RoleMakerBase to design a role maker
    for worker or pserver assignment.
    """

    def __init__(self):
        self._worker_endpoints = []
        self._server_endpoints = []
        self._cur_endpoint = ""
        self._role_is_generated = False
        self._role = None
        self._current_id = -1

    def _is_worker(self):
        """
        return is_worker() of current process
        """
        raise NotImplementedError("Please implement this method in child class")

    def _is_server(self):
        """
        return is_server() of current process
        """
        raise NotImplementedError("Please implement this method in child class")

    def _is_first_worker(self):
        """
        Check whether the node is the first instance of worker.
        Returns:
            bool: True if this is the first node of worker,
                  False if not.
        """
        raise NotImplementedError("Please implement this method in child class")

    def _worker_num(self):
        """
        Get current total worker number.

        Returns:
            int: worker number
        """
        raise NotImplementedError("Please implement this method in child class")

    def _server_num(self):
        """
        Get current total server number.

        Returns:
            int: server number
        """
        raise NotImplementedError("Please implement this method in child class")

    def _worker_index(self):
        """
        Get current worker id.

        Returns:
            int: node id
        """
        raise NotImplementedError("Please implement this method in child class")

    def _server_index(self):
        """
        Get current server id.

        Returns:
            int: node id
        """
        raise NotImplementedError("Please implement this method in child class")

    def _role_id(self):
        """
        Get current id.

        Returns:
            int: node id
        """
        raise NotImplementedError("Please implement this method in child class")

    def _node_num(self):
        """
        Get the training node number
        Returns:
            int: node num
        """
        raise NotImplementedError("Please implement this method in child class")

    def _get_trainer_endpoints(self):
        """
        return trainer endpoints
        """
        return self._worker_endpoints

    def _get_pserver_endpoints(self):
        """
        return pserver endpoints
        """
        return self._server_endpoints

    def to_string(self):
        return f"role: {self._role}, current_id: {self._current_id}, worker_endpoints: {self._worker_endpoints}, server_endpoints: {self._server_endpoints}"

    def _all_gather(self, input, comm_world="worker"):
        print("warning: RoleMakerBase does not have all gather worker.")

    def _all_reduce(self, input, mode="sum", comm_world="worker"):
        """
        Args:
            input(list/numpy.array): array of one dim
            output(list/numpy.array): array of one dim
            mode(str): "sum" or "min" or "max"
        """
        print("warning: RoleMakerBase does not have all reduce worker.")

    def _barrier(self, comm_world):
        """
        barrier between trainers if current role is TRAINER
        """
        print("warning: RoleMakerBase does not have barrier worker.")

    # def _is_heter_worker(self):
    #    """
    #    Return is_heter_worker() of current process
    #    """
    #    raise NotImplementedError("Please implement this method in child class")

    # def _heter_worker_num(self):
    #    """
    #    Get current total heter-worker number.
    #
    #    Returns:
    #        int: heter_worker number
    #    """
    #    raise NotImplementedError("Please implement this method in child class")

    # def _get_heter_worker_endpoints(self):
    #    """
    #    Returns:
    #        string: all heter_trainers'endpoints
    #    """
    #    raise NotImplementedError("Please implement this method in child class")

    # def _get_heter_worker_endpoint(self):
    #    """
    #    Returns:
    #        int: corresponding heter_trainer's endpoint
    #    """
    #    raise NotImplementedError("Please implement this method in child class")


class PaddleCloudRoleMaker(RoleMakerBase):
    """
    PaddleCloudRoleMaker is an interface for distributed configuration initialization based on obtaining distributed related information from environment variables.

    Examples:
        .. code-block:: python

            >>> import os
            >>> import paddle.distributed.fleet as fleet

            >>> os.environ["PADDLE_PSERVER_NUMS"] = "2"
            >>> os.environ["PADDLE_TRAINERS_NUM"] = "2"

            >>> os.environ["POD_IP"] = "127.0.0.1"
            >>> os.environ["PADDLE_PORT"] = "36001"
            >>> os.environ["TRAINING_ROLE"] = "PSERVER"
            >>> os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001,127.0.0.2:36001"

            >>> os.environ["PADDLE_TRAINER_ID"] = "0"

            >>> fleet.PaddleCloudRoleMaker(is_collective=False)

    """

    def __init__(self, is_collective: bool = False, **kwargs: Any) -> None:
        super().__init__()
        self._is_collective = is_collective
        self._non_distributed = False

        self._kwargs = kwargs
        self._role_is_generated = False

        # for heterps
        self._stage_id = 1
        self._stage_num = 1
        self._next_heter_trainer_endpoints = []
        self._previous_heter_trainer_endpoints = []
        self._heter_trainer_endpoints = []
        self._heter_trainer_device = "cpu"
        self._heter_trainer_device_type = "cpu"
        self._is_heter_parameter_server_mode = False
        self._stage_trainers = []

        self._server_endpoints = []
        self._worker_endpoints = []
        self._coordinator_endpoints = None
        self._with_coordinator = False

        self._gloo = Gloo()  # gloo instance

    def _barrier(self, comm_world: str) -> None:
        self._gloo.barrier(comm_world)

    def _all_gather(
        self, input: Any, comm_world: str = "worker"
    ) -> list[float]:
        return self._gloo.all_gather(input, comm_world)

    def _all_reduce(
        self, input: Any, mode: str = "sum", comm_world: str = "worker"
    ) -> npt.NDArray[Any]:
        return self._gloo.all_reduce(input, mode, comm_world)

    def _heter_device(self) -> str:
        """
        return the heter device that current heter worker is using
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._heter_trainer_device

    def _heter_device_type(self) -> str:
        """
        return the heter device type that current heter worker is using
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._heter_trainer_device_type

    def _get_stage_id(self) -> int:
        """
        return stage id of current heter worker
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._stage_id

    def _get_stage_trainers(self) -> list[int]:
        """
        return trainer num of all stages
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._stage_trainers

    def _get_num_stage(self) -> int:
        """
        return stage num
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._stage_num

    def _is_worker(self) -> bool:
        """
        whether current process is worker
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._role == Role.WORKER

    def _is_server(self) -> bool:
        """
        whether current process is server
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._role == Role.SERVER

    def _is_coordinator(self) -> bool:
        if not self._role_is_generated:
            self._generate_role()
        return self._role == Role.COORDINATOR

    def _is_first_worker(self) -> bool:
        """
        whether current process is worker of rank 0
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._role == Role.WORKER and self._current_id == 0

    def _worker_index(self) -> int:
        """
        get index of current worker
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._current_id

    def _server_index(self) -> int:
        """
        get index of current server
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._current_id

    def _role_id(self) -> int:
        """
        get index of current node
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._current_id

    def _worker_num(self) -> int:
        """
        return the current number of worker
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._trainers_num

    def _server_num(self) -> int:
        """
        return the current number of server
        """
        if not self._role_is_generated:
            self._generate_role()
        return (
            len(self._get_pserver_endpoints())
            if self._get_pserver_endpoints() is not None
            else 0
        )

    def _node_num(self) -> int:
        """
        return the training node number
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._nodes_num

    def _get_node_num(self) -> int:
        """
        return the training node number
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._nodes_num

    def _get_local_rank(self) -> str | None:
        if not self._role_is_generated:
            self._generate_role()
        return self._local_rank

    def _get_local_device_ids(self) -> str | None:
        if not self._role_is_generated:
            self._generate_role()
        return self._local_device_ids

    def _get_world_device_ids(self) -> str | None:
        if not self._role_is_generated:
            self._generate_role()
        return self._world_device_ids

    def _get_trainer_endpoints(self) -> list[str]:
        """
        get endpoint of all trainers
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._worker_endpoints

    def _get_trainer_endpoint(self) -> str:
        if not self._role_is_generated:
            self._generate_role()
        assert self._role == Role.WORKER, (
            "get_trainer_endpoint should be called by trainer"
        )
        return self._cur_endpoint

    def _get_heter_worker_endpoints(self) -> list[str]:
        """
        Returns:
            string: all heter_trainers'endpoints
        """
        if not self._role_is_generated:
            self._generate_role()
        assert self._heter_trainer_endpoints != [], (
            "Heter Worker Endpoints Not initialized"
        )
        return self._heter_trainer_endpoints

    def _get_heter_worker_endpoint(self) -> str:
        """
        Returns:
            str: corresponding heter_trainer's endpoint
        """
        if not self._role_is_generated:
            self._generate_role()
        assert self._role == Role.HETER_WORKER, (
            "_get_heter_worker_endpoint should be invoked by heter worker"
        )
        return self._cur_endpoint

    def _get_pserver_endpoints(self) -> list[str]:
        """
        get endpoint of all pservers
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._server_endpoints

    def _get_coordinator_endpoints(self) -> list[str]:
        if not self._role_is_generated:
            self._generate_role()
        return self._coordinator_endpoints

    def _get_previous_trainers(self):
        """
        invoked by heter worker
        """
        if not self._role_is_generated:
            self._generate_role()
        assert self._role in (
            Role.WORKER,
            Role.HETER_WORKER,
        ), "_get_previous_trainers should be invoked by trainer or heter worker"
        return self._previous_heter_trainer_endpoints

    def _get_next_trainers(self) -> list[str]:
        """
        invoked by heter worker
        """
        if not self._role_is_generated:
            self._generate_role()
        assert self._role in (
            Role.WORKER,
            Role.HETER_WORKER,
        ), "_get_next_trainers should be invoked by trainer or heter worker"
        return self._next_heter_trainer_endpoints

    def _is_non_distributed(self) -> bool:
        """
        Return True if indispensable environment for fleetrun is not found
        (use python-run to launch fleet-code directly)
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._non_distributed

    def _heter_worker_num(self) -> int:
        """
        get heter worker nums
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._heter_trainers_num

    def _is_heter_worker(self) -> bool:
        """
        whether current process is heter worker
        """
        if not self._role_is_generated:
            self._generate_role()
        return self._role == Role.HETER_WORKER

    def _ps_env(self) -> None:  # each role will execute it
        # Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
        # format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
        self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None)

        if self._server_endpoints is None:
            # back to non_distributed execution.
            self._server_endpoints = ""
            self._trainers_num = 1
            self._role = Role.WORKER
            self._current_id = 0
            self._nodes_num = 1
            self._heter_trainers_num = 0
            self._heter_trainer_endpoints = None
            self._non_distributed = True
            return

        self._server_endpoints = self._server_endpoints.split(",")

        self._worker_endpoints = getenv_or_backup(
            "PADDLE_TRAINER_ENDPOINTS", None
        )
        if self._worker_endpoints is not None:
            self._worker_endpoints = self._worker_endpoints.split(",")
        else:
            self._worker_endpoints = []

        self._coordinator_endpoints = os.getenv(
            "PADDLE_COORDINATOR_ENDPOINTS", ""
        )
        if self._coordinator_endpoints == "":
            print("fl-ps > coordinator address is null!")
        else:
            self._with_coordinator = True
            self._coordinator_endpoints = self._coordinator_endpoints.split(",")

        trainers_num = os.getenv("PADDLE_TRAINERS_NUM", None)
        if trainers_num is None:
            raise ValueError(
                "Can not find PADDLE_TRAINERS_NUM, please check your environment."
            )
        trainers_num = int(trainers_num)

        training_role = os.getenv("TRAINING_ROLE", None)
        if training_role is None:
            raise ValueError(
                "Can not find TRAINING_ROLE, please check your environment."
            )

        if training_role not in [
            "TRAINER",
            "PSERVER",
            "HETER_TRAINER",
            "COORDINATOR",
        ]:
            raise ValueError(
                f"TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER or COORDINATOR, but get {training_role}, please check your environment."
            )

        # For Heter Parameter Server env setting
        next_heter_trainer_eplist = os.getenv(
            "PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST", ""
        )
        previous_heter_trainer_eplist = os.getenv(
            "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST", ""
        )
        all_heter_trainer_eplist = os.getenv(
            "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST", ""
        )

        if all_heter_trainer_eplist != "":
            self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",")
            self._is_heter_parameter_server_mode = True
            self._heter_trainers_num = len(self._heter_trainer_endpoints)

            if previous_heter_trainer_eplist == "":
                assert training_role in (
                    "TRAINER",
                    "PSERVER",
                ), "training_role should be trainer or pserver"
            else:
                try:
                    self._previous_heter_trainer_endpoints = (
                        previous_heter_trainer_eplist.split(",")
                    )
                except:
                    raise ValueError(
                        "Can not Find PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
                    )

            if next_heter_trainer_eplist == "":
                assert training_role in (
                    "HETER_TRAINER",
                    "PSERVER",
                ), "training_role should be heter trainer or pserver"
            else:
                try:
                    self._next_heter_trainer_endpoints = (
                        next_heter_trainer_eplist.split(",")
                    )
                except:
                    raise ValueError(
                        "Can not Find PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
                    )

        else:
            self._is_heter_parameter_server_mode = False
            self._heter_trainers_num = 0

        if training_role == "TRAINER":
            role = Role.WORKER
            current_id = os.getenv("PADDLE_TRAINER_ID", None)
            if current_id is None:
                raise ValueError(
                    "Can not find PADDLE_TRAINER_ID, please check your environment."
                )
            current_id = int(current_id)
            if self._is_heter_parameter_server_mode:
                self._stage_id = os.getenv("STAGE_ID", None)
                if self._stage_id is None:
                    raise ValueError(
                        "Can not find STAGE_ID, please check your environment."
                    )
                self._stage_id = int(self._stage_id)
                self._stage_num = os.getenv("STAGE_NUM", None)
                if self._stage_num is None:
                    raise ValueError(
                        "Can not find STAGE_NUM, please check your environment."
                    )
                self._stage_num = int(self._stage_num)
                self._stage_trainers = os.getenv(
                    "PADDLE_STAGE_TRAINERS_NUM", None
                )
                if self._stage_trainers is None:
                    raise ValueError(
                        "Can not find PADDLE_STAGE_TRAINERS_NUM, please check your environment."
                    )
                self._stage_trainers = tuple(
                    [int(x) for x in re.findall(r'\d+', self._stage_trainers)]
                )
            cur_port = os.getenv("PADDLE_PORT", None)
            if cur_port is None:
                raise ValueError(
                    "Can not find PADDLE_PORT, please check your environment."
                )
            cur_ip = os.getenv("POD_IP", None)
            if cur_ip is None:
                raise ValueError(
                    "Can not find POD_IP, please check your environment."
                )
            curr_endpoint = ":".join([cur_ip, cur_port])
            self._cur_endpoint = curr_endpoint
        elif training_role == "COORDINATOR":
            print(">>> curr node is coordinator!")
            role = Role.COORDINATOR
            current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        elif training_role == "PSERVER":
            role = Role.SERVER
            cur_port = os.getenv("PADDLE_PORT", None)
            if cur_port is None:
                raise ValueError(
                    "Can not find PADDLE_PORT, please check your environment."
                )
            cur_ip = os.getenv("POD_IP", None)
            if cur_ip is None:
                raise ValueError(
                    "Can not find POD_IP, please check your environment."
                )
            curr_endpoint = ":".join([cur_ip, cur_port])
            self._cur_endpoint = curr_endpoint
            current_id = self._server_endpoints.index(self._cur_endpoint)
        elif training_role == "HETER_TRAINER":
            role = Role.HETER_WORKER
            self._stage_id = os.getenv("STAGE_ID", None)
            if self._stage_id is None:
                raise ValueError(
                    "Can not find STAGE_ID, please check your environment."
                )
            self._stage_id = int(self._stage_id)
            self._stage_num = os.getenv("STAGE_NUM", None)
            if self._stage_num is None:
                raise ValueError(
                    "Can not find STAGE_NUM, please check your environment."
                )
            self._stage_num = int(self._stage_num)

            self._stage_trainers = os.getenv("PADDLE_STAGE_TRAINERS_NUM", None)
            if self._stage_trainers is None:
                raise ValueError(
                    "Can not find PADDLE_STAGE_TRAINERS_NUM, please check your environment."
                )
            self._stage_trainers = tuple(
                [int(x) for x in re.findall(r'\d+', self._stage_trainers)]
            )

            self._heter_trainer_device_type = os.getenv(
                "HETER_DEVICE_TYPE", None
            )
            if self._heter_trainer_device_type is None:
                raise ValueError(
                    "Can not find HETER_DEVICE_TYPE, please check your environment."
                )
            assert self._heter_trainer_device_type in (
                "cpu",
                "gpu",
                "xpu",
            ), "HETER_DEVICE_TYPE should be cpu,gpu or xpu"
            if self._heter_trainer_device_type == "gpu":
                heter_device_id = os.getenv("FLAGS_selected_gpus", "0")
                self._heter_trainer_device = ":".join(
                    (self._heter_trainer_device_type, heter_device_id)
                )
            if self._heter_trainer_device == "xpu":
                heter_device_id = os.getenv("FLAGS_selected_xpus", "0")
                self._heter_trainer_device = ":".join(
                    (self._heter_trainer_device_type, heter_device_id)
                )

            cur_port = os.getenv("PADDLE_PORT", None)
            if cur_port is None:
                raise ValueError(
                    "Can not find PADDLE_PORT, please check your environment."
                )
            cur_ip = os.getenv("POD_IP", None)
            if cur_ip is None:
                raise ValueError(
                    "Can not find POD_IP, please check your environment."
                )
            curr_endpoint = ":".join([cur_ip, cur_port])
            self._cur_endpoint = curr_endpoint
            current_id = (
                all_heter_trainer_eplist.split(",").index(curr_endpoint)
                + trainers_num
            )

        self._trainers_num = trainers_num
        self._role = role
        self._current_id = current_id
        self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints})

    def _collective_env(self) -> None:
        self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
        assert self._training_role == "TRAINER"
        self._role = Role.WORKER
        self._worker_endpoints = getenv_or_backup("PADDLE_TRAINER_ENDPOINTS")
        self._cur_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
        if self._worker_endpoints is None:
            # back to non_distributed execution.
            self._worker_endpoints = "127.0.0.1:6170"
            self._cur_endpoint = self._worker_endpoints
            self._non_distributed = True
        self._worker_endpoints = self._worker_endpoints.split(",")
        self._trainers_num = len(self._worker_endpoints)
        auto_tuner = os.getenv("PADDLE_AUTO_PARALLEL_CONFIG", None)
        if auto_tuner is not None:
            trainers_num = os.getenv("PADDLE_TRAINERS_NUM", None)
            self._trainers_num = int(trainers_num)
        self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints})
        self._local_rank = os.getenv("PADDLE_RANK_IN_NODE")
        self._local_device_ids = os.getenv("PADDLE_LOCAL_DEVICE_IDS")
        self._world_device_ids = os.getenv("PADDLE_WORLD_DEVICE_IDS")

    def _gloo_init(self) -> None:
        # PADDLE_WITH_GLOO 1: trainer barrier, 2: all barrier
        use_gloo = int(os.getenv("PADDLE_WITH_GLOO", "0"))
        if use_gloo not in [1, 2]:
            return

        # PADDLE_GLOO_RENDEZVOUS 1: HDFS 2: FILE 3: HTTP
        rendezvous_type = int(os.getenv("PADDLE_GLOO_RENDEZVOUS", "0"))
        prefix = os.getenv("SYS_JOB_ID", "")
        if rendezvous_type not in [
            Gloo.RENDEZVOUS.HDFS,
            Gloo.RENDEZVOUS.HTTP,
            Gloo.RENDEZVOUS.FILE,
        ]:
            raise ValueError(self._gloo._err_type)

        need_init_all = True if use_gloo == 2 else False

        if rendezvous_type == Gloo.RENDEZVOUS.HDFS:
            dfs_name = os.getenv("PADDLE_GLOO_FS_NAME", "")
            dfs_ugi = os.getenv("PADDLE_GLOO_FS_UGI", "")
            dfs_path = os.getenv("PADDLE_GLOO_FS_PATH", "")
            kwargs = {
                "dfs.name": dfs_name,
                "dfs.ugi": dfs_ugi,
                "dfs.path": dfs_path,
                "store.prefix": prefix,
            }
        elif rendezvous_type == Gloo.RENDEZVOUS.HTTP:
            start_http_server = False
            manager = Manager()
            http_server_d = manager.dict()
            http_server_d["running"] = False
            if self._is_collective:
                ep_rank_0 = self._worker_endpoints[0]
                if self._is_first_worker():
                    start_http_server = True
            else:
                ep_rank_0 = os.getenv("PADDLE_GLOO_HTTP_ENDPOINT", "")
                if self._is_server() and self._server_index() == 0:
                    start_http_server = True
            ip, port = ep_rank_0.split(':')
            kwargs = {
                "http.host": ip,
                "http.port": port,
                "store.prefix": prefix,
                'start_http_server': start_http_server,
                'http_server_d': http_server_d,
            }
        else:
            dfs_path = os.getenv("PADDLE_GLOO_FS_PATH", "")
            kwargs = {
                "dfs.path": dfs_path,
                "store.prefix": prefix,
            }

        if rendezvous_type == Gloo.RENDEZVOUS.HDFS:
            type = "HDFS"
        elif rendezvous_type == Gloo.RENDEZVOUS.HTTP:
            type = "HTTP"
        else:
            type = "FILE"
        print(
            f"Gloo init with {type}: need_init_all: {need_init_all}, args: {kwargs}"
        )

        self._gloo.init(
            rendezvous=rendezvous_type,
            role=self._role,
            role_id=self._role_id(),
            worker_num=self._worker_num(),
            server_num=self._server_num(),
            need_init_all=need_init_all,
            kwargs=kwargs,
        )

        if rendezvous_type == Gloo.RENDEZVOUS.HTTP:
            http_server_d['running'] = False

    def _generate_role(self) -> None:
        """
        generate role for role maker
        """
        if not self._role_is_generated:
            if not self._is_collective:
                self._ps_env()
            else:
                self._collective_env()
            self._role_is_generated = True
            if not paddle.in_dynamic_mode():
                self._gloo_init()


class UserDefinedRoleMaker(PaddleCloudRoleMaker):
    """
    UserDefinedRoleMaker is an interface for distributed configuration initialization based on obtaining distributed related information from user-defined parameters.

    Examples:
        .. code-block:: python

            >>> import paddle.distributed.fleet as fleet
            >>> from paddle.distributed.fleet.base.role_maker import Role

            >>> fleet.UserDefinedRoleMaker(
            ...     current_id=0,
            ...     role=Role.SERVER,
            ...     worker_num=2,
            ...     server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"])
    """

    def __init__(
        self,
        is_collective: bool = False,
        init_gloo: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            is_collective=is_collective, init_gloo=init_gloo, **kwargs
        )
        self._init_gloo = init_gloo

    def _user_defined_ps_env(self) -> None:
        self._server_endpoints = self._kwargs.get("server_endpoints")
        self._worker_endpoints = self._kwargs.get("worker_endpoints", [])
        self._trainers_num = self._kwargs.get("worker_num", 0)

        if self._trainers_num == 0:
            assert len(self._worker_endpoints) > 0
            self._trainers_num = len(self._worker_endpoints)

        self._role = self._kwargs.get("role")
        self._current_id = self._kwargs.get("current_id")

        if (
            self._role == Role.WORKER
            and len(self._worker_endpoints) > self._current_id
        ):
            self._cur_endpoint = self._worker_endpoints[self._current_id]
        elif self._role == Role.SERVER:
            self._cur_endpoint = self._server_endpoints[self._current_id]
        self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints})

    def _user_defined_collective_env(self) -> None:
        self._worker_endpoints = self._kwargs.get("worker_endpoints")
        self._current_id = self._kwargs.get("current_id")
        self._trainers_num = len(self._worker_endpoints)
        self._training_role = Role.WORKER
        self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints})

    def _generate_role(self) -> None:
        """
        generate role for role maker
        """
        if not self._role_is_generated:
            if not self._is_collective:
                self._user_defined_ps_env()
            else:
                self._user_defined_collective_env()
            self._role_is_generated = True
