# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# C++ header preamble emitted at the top of the generated file.
# Placeholders ${num_configs_fp16} / ${num_configs_fp32} are filled in main()
# with the final (level-filtered, deduplicated) config counts.
# The batched-swizzle specialization is intentionally left commented out in
# the emitted code.
head_template = """// auto-generated by generate_configs.py

#pragma once

#include "cutlass/gemm_coord.h"

namespace ap {

constexpr int kNumConfigsHalf = ${num_configs_fp16};
constexpr int kNumConfigsFloat = ${num_configs_fp32};

template <int SwizzleFactor, bool Batched> struct SwizzleWrapper {
  using Type =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<SwizzleFactor>;
};

// template <int SwizzleFactor>
// struct SwizzleWrapper<SwizzleFactor, true> {
//   using Type =
//       cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;
// };
"""

# Template for the emitted AP_AUTOTUNE_${datatype} C macro: it builds a static
# table of func<Id> instantiations, profiles them once (lazily, on first call)
# via ap::ProfileBestConfig to pick the fastest config id, then dispatches
# every subsequent call to the winner.  ${repeat_functions} expands to the
# comma-separated "func<0>, func<1>, ..." list built in
# generate_autotune_wrapper().
autotune_wrapper_template = """
#define AP_AUTOTUNE_${datatype}(func, stream, ...) { \\
  using FuncType = decltype(func<0>); \\
  static int selected_config_id = -1; \\
  static std::vector<std::function<FuncType>> \\
      matmul_functions = { \\
          ${repeat_functions} \\
          };  \\
  if (selected_config_id == -1) { \\
    selected_config_id = ap::ProfileBestConfig(matmul_functions, stream, ##__VA_ARGS__); \\
  } \\
  matmul_functions[selected_config_id](__VA_ARGS__); \\
}
"""

# Primary template for GemmTuningConfigs (used for config id 0).  It also
# serves as the fallback for any unspecialized Id.  Placeholders: ${tshape},
# ${wshape}, ${ishape} are comma-joined GemmShape dimensions; ${stages} is the
# pipeline stage count.
fp16_config_template_0 = """
template <typename ElementT, int SwizzleFactor, bool Batched, int Id = 0>
struct GemmTuningConfigs {
  using TShape = cutlass::gemm::GemmShape<${tshape}>;
  using WShape = cutlass::gemm::GemmShape<${wshape}>;
  using IShape = cutlass::gemm::GemmShape<${ishape}>;
  static constexpr int kNumStages = ${stages};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;
  static constexpr int kId = Id;
};
"""

# Partial specialization emitted once per non-zero config id; ${config_id} is
# substituted with the sequential id assigned in generate_configs().
fp16_config_template = """
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ${config_id}> {
  using TShape = cutlass::gemm::GemmShape<${tshape}>;
  using WShape = cutlass::gemm::GemmShape<${wshape}>;
  using IShape = cutlass::gemm::GemmShape<${ishape}>;
  static constexpr int kNumStages = ${stages};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;
  static constexpr int kId = ${config_id};
};
"""

# Float specialization of GemmTuningConfigs for config id 0 (ElementT pinned
# to float; Id stays a free template parameter so it doubles as the float
# fallback for unspecialized ids).
fp32_config_template_0 = """
// Specialization for float
template <int SwizzleFactor, bool Batched, int Id>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, Id> {
  using TShape = cutlass::gemm::GemmShape<${tshape}>;
  using WShape = cutlass::gemm::GemmShape<${wshape}>;
  using IShape = cutlass::gemm::GemmShape<${ishape}>;
  static constexpr int kNumStages = ${stages};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;
  static constexpr int kId = Id;
};
"""

# Float specialization emitted once per non-zero config id.
fp32_config_template = """
template <int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, ${config_id}> {
  using TShape = cutlass::gemm::GemmShape<${tshape}>;
  using WShape = cutlass::gemm::GemmShape<${wshape}>;
  using IShape = cutlass::gemm::GemmShape<${ishape}>;
  static constexpr int kNumStages = ${stages};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;
  static constexpr int kId = ${config_id};
};
"""

# Closes the `namespace ap {` opened in head_template; written last.
tail_code_str = """
} // namespace ap
"""


class GemmTuningConfig:
    """One GEMM tuning-config candidate fed into the header generator.

    Attributes:
        tshape: GemmShape dims for TShape (presumably the threadblock tile,
            matching CUTLASS naming -- TODO confirm).
        wshape: GemmShape dims for WShape (presumably the warp tile).
        ishape: GemmShape dims for IShape (presumably the instruction tile).
        stages: Pipeline stage count (kNumStages in the emitted code).
        level: Tuning level used only to filter candidates in get_configs();
            deliberately excluded from equality, since two configs emitting
            identical C++ are duplicates regardless of level.
    """

    def __init__(self, tshape, wshape, ishape, stages, level):
        self.tshape = tshape
        self.wshape = wshape
        self.ishape = ishape
        self.stages = stages
        self.level = level

    def __repr__(self):
        return (
            f"{type(self).__name__}(tshape={self.tshape}, "
            f"wshape={self.wshape}, ishape={self.ishape}, "
            f"stages={self.stages}, level={self.level})"
        )

    def __eq__(self, other):
        """Value equality on the emitted fields (shapes + stages).

        Returns NotImplemented for foreign types instead of crashing; list
        equality replaces the previous element-by-element loop and handles
        mismatched lengths by returning False rather than asserting.
        """
        if not isinstance(other, GemmTuningConfig):
            return NotImplemented
        return (
            self.tshape == other.tshape
            and self.wshape == other.wshape
            and self.ishape == other.ishape
            and self.stages == other.stages
        )


def all_configs_sm80_fp16():
    """Return every sm80 fp16 GEMM tuning-config candidate.

    Each spec row is (tshape, wshape, ishape, stages, level); rows are
    grouped by ascending stage count.
    """
    specs = [
        ([16, 128, 64], [16, 32, 64], [16, 8, 16], 2, 2),
        ([32, 128, 64], [32, 32, 64], [16, 8, 16], 2, 2),
        ([64, 128, 64], [64, 64, 64], [16, 8, 16], 2, 2),
        ([128, 128, 64], [64, 64, 64], [16, 8, 16], 2, 2),
        ([128, 128, 64], [128, 32, 64], [16, 8, 16], 2, 2),
        ([128, 256, 64], [64, 64, 64], [16, 8, 16], 2, 2),
        ([256, 128, 64], [64, 64, 64], [16, 8, 16], 2, 1),
        ([16, 128, 64], [16, 32, 64], [16, 8, 16], 3, 3),
        ([32, 128, 64], [32, 32, 64], [16, 8, 16], 3, 2),
        ([64, 128, 64], [32, 64, 64], [16, 8, 16], 3, 1),
        ([64, 128, 64], [64, 64, 64], [16, 8, 16], 3, 1),
        ([64, 256, 64], [64, 64, 64], [16, 8, 16], 3, 3),
        ([128, 64, 64], [64, 32, 64], [16, 8, 16], 3, 1),
        ([128, 128, 32], [64, 64, 32], [16, 8, 16], 3, 1),
        ([128, 128, 64], [64, 64, 64], [16, 8, 16], 3, 1),
        ([128, 128, 64], [128, 32, 64], [16, 8, 16], 3, 2),
        ([128, 256, 32], [64, 64, 32], [16, 8, 16], 3, 2),
        ([128, 256, 64], [64, 64, 64], [16, 8, 16], 3, 2),
        ([256, 64, 32], [64, 64, 32], [16, 8, 16], 3, 1),
        ([256, 64, 64], [64, 64, 64], [16, 8, 16], 3, 1),
        ([256, 128, 32], [64, 64, 32], [16, 8, 16], 3, 1),
        ([256, 128, 64], [64, 64, 64], [16, 8, 16], 3, 1),
        ([16, 128, 64], [16, 32, 64], [16, 8, 16], 4, 3),
        ([32, 128, 64], [32, 32, 64], [16, 8, 16], 4, 2),
        ([64, 128, 64], [64, 64, 64], [16, 8, 16], 4, 2),
        ([64, 256, 32], [64, 64, 32], [16, 8, 16], 4, 2),
        ([64, 256, 64], [64, 64, 64], [16, 8, 16], 4, 3),
        ([128, 32, 64], [32, 32, 64], [16, 8, 16], 4, 1),
        ([128, 128, 32], [64, 64, 32], [16, 8, 16], 4, 1),
        ([128, 128, 64], [64, 64, 64], [16, 8, 16], 4, 1),
        ([128, 128, 64], [128, 32, 64], [16, 8, 16], 4, 2),
        ([256, 64, 64], [64, 64, 64], [16, 8, 16], 4, 1),
        ([256, 64, 32], [64, 64, 32], [16, 8, 16], 4, 1),
        ([16, 64, 64], [16, 32, 64], [16, 8, 16], 5, 2),
        ([16, 128, 64], [16, 32, 64], [16, 8, 16], 5, 3),
        ([32, 64, 64], [16, 32, 64], [16, 8, 16], 5, 1),
        ([32, 128, 64], [32, 32, 64], [16, 8, 16], 5, 3),
        ([64, 64, 64], [32, 32, 64], [16, 8, 16], 5, 1),
        ([64, 128, 64], [64, 64, 64], [16, 8, 16], 5, 3),
        ([128, 128, 32], [64, 64, 32], [16, 8, 16], 5, 1),
        ([128, 128, 64], [64, 64, 64], [16, 8, 16], 5, 1),
        ([128, 128, 64], [128, 32, 64], [16, 8, 16], 5, 2),
        ([64, 128, 32], [32, 64, 32], [16, 8, 16], 6, 1),
        ([128, 64, 32], [64, 32, 32], [16, 8, 16], 6, 1),
        ([128, 32, 32], [32, 32, 32], [16, 8, 16], 7, 1),
        ([64, 64, 32], [32, 32, 32], [16, 8, 16], 10, 1),
    ]
    return [
        GemmTuningConfig(tshape, wshape, ishape, stages, level)
        for tshape, wshape, ishape, stages, level in specs
    ]


def all_configs_sm80_fp32():
    """Return every sm80 fp32 GEMM tuning-config candidate.

    Each spec row is (tshape, wshape, ishape, stages, level); rows are
    grouped by ascending stage count.
    """
    specs = [
        ([64, 64, 16], [32, 32, 16], [16, 8, 8], 3, 1),
        ([64, 64, 32], [32, 32, 32], [16, 8, 8], 3, 1),
        ([64, 128, 32], [32, 64, 32], [16, 8, 8], 3, 1),
        ([64, 256, 16], [32, 64, 16], [16, 8, 8], 3, 1),
        ([64, 256, 32], [32, 64, 32], [16, 8, 8], 3, 1),
        ([128, 64, 32], [64, 32, 32], [16, 8, 8], 3, 1),
        ([128, 128, 16], [32, 64, 16], [16, 8, 8], 3, 1),
        ([128, 128, 32], [32, 64, 32], [16, 8, 8], 3, 1),
        ([256, 64, 16], [64, 32, 16], [16, 8, 8], 3, 1),
        ([256, 64, 32], [64, 32, 32], [16, 8, 8], 3, 1),
        ([64, 128, 16], [32, 64, 16], [16, 8, 8], 4, 1),
        ([128, 64, 16], [64, 32, 16], [16, 8, 8], 4, 1),
        ([128, 128, 16], [32, 64, 16], [16, 8, 8], 4, 1),
    ]
    return [
        GemmTuningConfig(tshape, wshape, ishape, stages, level)
        for tshape, wshape, ishape, stages, level in specs
    ]


def generate_autotune_wrapper(datatype, num_configs):
    """Render the AP_AUTOTUNE_${datatype} macro for one element type.

    Args:
        datatype: Macro-name suffix (e.g. "half", "float").
        num_configs: Number of func<Id> instantiations to register.

    Returns:
        The filled-in C macro definition as a string.
    """
    # One "func<i>" entry per config id, joined with the backslash
    # line-continuation separator the C macro's initializer list needs.
    repeat_func_strs = [f"func<{i}>" for i in range(num_configs)]
    code_str = autotune_wrapper_template.replace(
        "${datatype}", datatype
    ).replace("${repeat_functions}", ", \\\n          ".join(repeat_func_strs))
    return code_str


def get_configs(all_configs_list, level=3):
    """Filter candidate configs by tuning level and drop duplicates.

    Args:
        all_configs_list: Candidate configs; each element must expose a
            ``level`` attribute and value-based ``__eq__``.
        level: Keep only configs whose level is <= this threshold.

    Returns:
        The surviving configs, in their original order.
    """
    configs_list = []
    for i, config in enumerate(all_configs_list):
        if config.level > level:
            continue
        # Membership test uses the element's __eq__ (shapes + stages);
        # list scan is O(n^2) overall but n is tiny (dozens of configs).
        if config in configs_list:
            print(f"-- The {i}-th config is repeat.")
            continue
        configs_list.append(config)
    return configs_list


def generate_configs(configs_list, config_template_0, config_template):
    """Render the C++ struct definitions for a list of tuning configs.

    Args:
        configs_list: Configs exposing tshape/wshape/ishape/stages.
        config_template_0: Template used for the first config (id 0).
        config_template: Template used for every subsequent config; may
            reference ${config_id}.

    Returns:
        Tuple of (number of configs, concatenated generated code).
    """
    code_str = ""
    for config_id, config in enumerate(configs_list):
        # The id-0 primary template and the per-id specialization share the
        # same placeholders, so a single fill chain serves both; the
        # ${config_id} substitution is a no-op on the id-0 template.
        template = config_template_0 if config_id == 0 else config_template
        code_str += (
            template.replace("${tshape}", ", ".join(map(str, config.tshape)))
            .replace("${wshape}", ", ".join(map(str, config.wshape)))
            .replace("${ishape}", ", ".join(map(str, config.ishape)))
            .replace("${stages}", str(config.stages))
            .replace("${config_id}", str(config_id))
        )
    return len(configs_list), code_str


def main(level=1, output_path="all_tuning_configs.h"):
    """Generate the auto-tuned GEMM config header.

    Args:
        level: Maximum tuning level to include; was previously hard-coded.
            Lower values keep a smaller, faster-to-profile candidate set.
        output_path: Destination header file; was previously hard-coded.
    """
    num_fp16_configs, fp16_configs_code_str = generate_configs(
        configs_list=get_configs(all_configs_sm80_fp16(), level=level),
        config_template_0=fp16_config_template_0,
        config_template=fp16_config_template,
    )
    num_fp32_configs, fp32_configs_code_str = generate_configs(
        configs_list=get_configs(all_configs_sm80_fp32(), level=level),
        config_template_0=fp32_config_template_0,
        config_template=fp32_config_template,
    )
    print(
        f"-- Total {num_fp16_configs} fp16 configs, {num_fp32_configs} fp32 configs"
    )
    # The header preamble embeds the final config counts as constexpr ints.
    head_code_str = head_template.replace(
        "${num_configs_fp16}", str(num_fp16_configs)
    ).replace("${num_configs_fp32}", str(num_fp32_configs))
    fp16_autotune_wrapper_code_str = generate_autotune_wrapper(
        "half", num_fp16_configs
    )
    fp32_autotune_wrapper_code_str = generate_autotune_wrapper(
        "float", num_fp32_configs
    )
    # Emission order: preamble, both autotune macros, both config families,
    # then the closing namespace brace.
    with open(output_path, "w") as f:
        f.write(head_code_str)
        f.write(fp16_autotune_wrapper_code_str)
        f.write(fp32_autotune_wrapper_code_str)
        f.write(fp16_configs_code_str)
        f.write(fp32_configs_code_str)
        f.write(tail_code_str)


if __name__ == "__main__":
    main()
