#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os

from ...strategy import Strategy

# Names of the strategy pass configs that `TuningConfig._initialize` inspects
# for an `enable_tuning` flag; only these passes can be selected for tuning.
_tuning_supported_passes = ["sharding", "recompute"]


def _get_pass_config(strategy, pass_name):
    """Return the attribute of `strategy` named `pass_name`.

    This is a plain attribute lookup, so an `AttributeError` is raised when
    `strategy` has no attribute with that name.
    """
    return getattr(strategy, pass_name)


class TuningConfig:
    """
    A uniform config wrapper for optimization tuning. It combines:

    distributed strategy: a private deep copy of the user defined `Strategy`,
        i.e. the configuration of the optimization passes.
    tuning config: configuration of the tuning process itself, read from
        `strategy.tuning`: mode (profile or cost model), profile step range,
        project (log) dir, max number of trials, early stop and debug flag.

    Defaults applied when a key is absent from `strategy.tuning`:
        mode="PROFILE", profile_start_step=10, profile_end_step=30,
        max_num_trial=50, early_stop=None, debug=False,
        project_dir="<current working dir>/OptimizationTuning" (also used
        when the configured value is empty).

    For every pass listed in `_tuning_supported_passes` whose config is truthy
    and has `enable_tuning` set, the pass name is added to
    `tuning_passes_name` and the pass config is exposed as an instance
    attribute of the same name.

    Any attribute not found on the `TuningConfig` itself is looked up on the
    wrapped strategy copy (see `__getattr__`).

    Raises:
        TypeError: if `strategy` is not an instance of `Strategy`.
    """

    def __init__(self, strategy):
        if not isinstance(strategy, Strategy):
            raise TypeError("'strategy' must be object of class `Strategy`.")

        self._tuning_passes_name = set()
        # Keep a private deep copy instead of a reference to the caller's object.
        self._dist_strategy = copy.deepcopy(strategy)
        self._mode = None
        self._profile_start_step = None
        self._profile_end_step = None
        self._project_dir = None
        self._max_num_trial = None
        self._early_stop = None
        self._debug = None

        self._initialize()

    @property
    def mode(self):
        """Tuning mode read from `strategy.tuning` (default "PROFILE")."""
        return self._mode

    @property
    def profile_start_step(self):
        """Configured `profile_start_step` (default 10)."""
        return self._profile_start_step

    @property
    def profile_end_step(self):
        """Configured `profile_end_step` (default 30)."""
        return self._profile_end_step

    @property
    def project_dir(self):
        """Configured project dir, or "<cwd>/OptimizationTuning" if unset."""
        return self._project_dir

    @property
    def tuning_passes_name(self):
        """Set of supported pass names that have `enable_tuning` set."""
        return self._tuning_passes_name

    @property
    def max_num_trial(self):
        """Configured `max_num_trial` (default 50)."""
        return self._max_num_trial

    @property
    def early_stop(self):
        """Configured `early_stop` value (default None)."""
        return self._early_stop

    @property
    def debug(self):
        """Configured `debug` flag (default False)."""
        return self._debug

    @property
    def dist_strategy(self):
        """The deep-copied `Strategy` wrapped by this config."""
        return self._dist_strategy

    # initialize config with user define value or default value
    def _initialize(self):
        tuning_strategy = self._dist_strategy.tuning

        self._mode = tuning_strategy.get("mode", "PROFILE")
        self._profile_start_step = tuning_strategy.get("profile_start_step", 10)
        self._profile_end_step = tuning_strategy.get("profile_end_step", 30)
        self._max_num_trial = tuning_strategy.get("max_num_trial", 50)
        self._early_stop = tuning_strategy.get("early_stop", None)
        self._debug = tuning_strategy.get("debug", False)

        # An empty / missing project dir falls back to a folder under the cwd.
        self._project_dir = tuning_strategy.get(
            "project_dir", None
        ) or os.path.join(os.getcwd(), "OptimizationTuning")

        for p in _tuning_supported_passes:
            # Fetch the pass config once and reuse it for both the check and
            # the attribute that is exposed on this object.
            p_strategy = _get_pass_config(self._dist_strategy, p)
            if p_strategy and p_strategy.enable_tuning:
                # TODO distinguish different args of each passes
                self._tuning_passes_name.add(p)
                setattr(self, p, p_strategy)
                # TODO verify the user defined configs and merge the per-pass
                # entries of `tuning_strategy.get(p)` into the pass config.

    # (NOTE)tuning config ONLY wraps dist strategy for pass config which is to be tuned
    def __getattr__(self, item):
        """Forward attributes missing on this object to the wrapped strategy.

        Python calls this only after normal attribute lookup has failed.
        The strategy is read straight from the instance `__dict__`: going
        through `self._dist_strategy` would re-enter `__getattr__` whenever
        that attribute has not been set yet (e.g. on the bare instance that
        `copy.deepcopy` / `pickle` create before restoring state) and recurse
        until `RecursionError`.

        Raises:
            AttributeError: if the strategy is not set yet, or if the strategy
                itself has no attribute named `item`.
        """
        try:
            dist_strategy = self.__dict__["_dist_strategy"]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{item}'"
            ) from None
        return getattr(dist_strategy, item)
