// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// auto-generated by generate_configs.py

#pragma once

#include "cutlass/gemm_coord.h"

namespace ap {

// Number of tuning configurations enumerated per element family. The values
// match the number of func<N> entries listed by AP_AUTOTUNE_half (23) and
// AP_AUTOTUNE_float (13) respectively.
constexpr int kNumConfigsHalf{23};
constexpr int kNumConfigsFloat{13};

// Maps (SwizzleFactor, Batched) to a CUTLASS threadblock swizzle type, exposed
// as the nested alias `Type`.
//
// Only this primary template is active in this file, so `Type` is always
// GemmIdentityThreadblockSwizzle<SwizzleFactor> and the `Batched` argument
// currently has no effect on the result.
template <int SwizzleFactor, bool Batched>
struct SwizzleWrapper {
  using Type =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<SwizzleFactor>;
};

// template <int SwizzleFactor>
// struct SwizzleWrapper<SwizzleFactor, true> {
//   using Type =
//       cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;
// };

// AP_AUTOTUNE_half(func, stream, args...)
//
// Expands to a brace-enclosed block that picks one of 23 instantiations of the
// template `func` (func<0> .. func<22>) and calls it with `args...`.
//
//   func   - name of a template taking an int id; every func<N> for N in
//            [0, 22] must be storable in std::function<decltype(func<0>)>.
//   stream - forwarded only to ap::ProfileBestConfig; it reaches the selected
//            function only if the caller also lists it in `args...`.
//   args   - forwarded to ap::ProfileBestConfig (after `stream`) and to the
//            selected function.
//
// The block holds two function-local statics: the candidate table and the
// chosen index (initially -1). While the index is -1 the block calls
// ap::ProfileBestConfig(matmul_functions, stream, args...) and stores the
// result; afterwards the stored index is reused without looking at `args`
// again, so one choice is shared by every execution of that expansion.
//
// NOTE(review): the value returned by ap::ProfileBestConfig is used as a
// vector index without a range check; ap::ProfileBestConfig is not defined in
// this header — confirm it always returns a value in [0, 23).
// NOTE(review): the expansion needs std::vector, std::function and
// ap::ProfileBestConfig to be visible at the call site; this header does not
// include or declare them itself.
// NOTE(review): the check/update of selected_config_id is not synchronized in
// this macro — confirm callers do not hit the same expansion concurrently.
#define AP_AUTOTUNE_half(func, stream, ...)                               \
  {                                                                       \
    using FuncType = decltype(func<0>);                                   \
    static int selected_config_id = -1;                                   \
    static std::vector<std::function<FuncType>> matmul_functions = {      \
        func<0>,  func<1>,  func<2>,  func<3>,  func<4>,  func<5>,        \
        func<6>,  func<7>,  func<8>,  func<9>,  func<10>, func<11>,       \
        func<12>, func<13>, func<14>, func<15>, func<16>, func<17>,       \
        func<18>, func<19>, func<20>, func<21>, func<22>};                \
    if (selected_config_id == -1) {                                       \
      selected_config_id =                                                \
          ap::ProfileBestConfig(matmul_functions, stream, ##__VA_ARGS__); \
    }                                                                     \
    matmul_functions[selected_config_id](__VA_ARGS__);                    \
  }

// AP_AUTOTUNE_nv_bfloat16(func, stream, args...)
//
// Plain alias of AP_AUTOTUNE_half: all arguments are forwarded unchanged, so
// the same func<0> .. func<22> candidates and selection logic are used.
#define AP_AUTOTUNE_nv_bfloat16(func, stream, ...) \
  AP_AUTOTUNE_half(func, stream, __VA_ARGS__)

// AP_AUTOTUNE_float(func, stream, args...)
//
// Expands to a brace-enclosed block that picks one of 13 instantiations of the
// template `func` (func<0> .. func<12>) and calls it with `args...`.
//
// The candidate table and the chosen index are function-local statics. While
// the index is still -1 the block asks
// ap::ProfileBestConfig(matmul_functions, stream, args...) for an index and
// stores it; every later execution of the same expansion reuses that index.
// `stream` is passed to ap::ProfileBestConfig only, not to the selected
// function.
//
// NOTE(review): the index returned by ap::ProfileBestConfig is not range
// checked here — confirm it is always in [0, 13).
// NOTE(review): std::vector, std::function and ap::ProfileBestConfig must be
// visible where the macro is expanded; this header does not provide them.
#define AP_AUTOTUNE_float(func, stream, ...)                              \
  {                                                                       \
    using FuncType = decltype(func<0>);                                   \
    static int selected_config_id = -1;                                   \
    static std::vector<std::function<FuncType>> matmul_functions = {      \
        func<0>,  func<1>,  func<2>,  func<3>,  func<4>,  func<5>,        \
        func<6>,  func<7>,  func<8>,  func<9>,  func<10>, func<11>,       \
        func<12>};                                                        \
    if (selected_config_id == -1) {                                       \
      selected_config_id =                                                \
          ap::ProfileBestConfig(matmul_functions, stream, ##__VA_ARGS__); \
    }                                                                     \
    matmul_functions[selected_config_id](__VA_ARGS__);                    \
  }

// Compile-time GEMM tuning parameters, selected by element type and config id.
//
// Each variant exposes the same members: three GemmShape aliases (TShape,
// WShape, IShape), a stage count (kNumStages), a threadblock swizzle type
// (SwizzleThreadBlock, taken from SwizzleWrapper) and the config id (kId).
// TShape / WShape / IShape presumably feed CUTLASS's threadblock / warp /
// instruction tile parameters — confirm at the use site.
//
// This primary template is the fallback for every (ElementT, Id) pair that no
// partial specialization matches, including the default Id = 0:
// TShape 256x128x64, WShape 64x64x64, IShape 16x8x16, 2 stages. kId mirrors
// the Id argument instead of being a literal.
template <typename ElementT, int SwizzleFactor, bool Batched, int Id = 0>
struct GemmTuningConfigs {
  static constexpr int kId{Id};
  static constexpr int kNumStages{2};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<256, 128, 64>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 1: TShape 64x128x64, WShape 32x64x64,
// IShape 16x8x16, 3 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 1> {
  static constexpr int kId{1};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<64, 128, 64>;
  using WShape = cutlass::gemm::GemmShape<32, 64, 64>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 2: TShape 64x128x64, WShape 64x64x64,
// IShape 16x8x16, 3 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 2> {
  static constexpr int kId{2};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<64, 128, 64>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 3: TShape 128x64x64, WShape 64x32x64,
// IShape 16x8x16, 3 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 3> {
  static constexpr int kId{3};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 64, 64>;
  using WShape = cutlass::gemm::GemmShape<64, 32, 64>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 4: TShape 128x128x32, WShape 64x64x32,
// IShape 16x8x16, 3 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 4> {
  static constexpr int kId{4};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 128, 32>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 5: TShape 128x128x64, WShape 64x64x64,
// IShape 16x8x16, 3 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 5> {
  static constexpr int kId{5};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 128, 64>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 6: TShape 256x64x32, WShape 64x64x32,
// IShape 16x8x16, 3 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 6> {
  static constexpr int kId{6};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<256, 64, 32>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 7: TShape 256x64x64, WShape 64x64x64,
// IShape 16x8x16, 3 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 7> {
  static constexpr int kId{7};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<256, 64, 64>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 8: TShape 256x128x32, WShape 64x64x32,
// IShape 16x8x16, 3 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 8> {
  static constexpr int kId{8};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<256, 128, 32>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 9: TShape 256x128x64, WShape 64x64x64,
// IShape 16x8x16, 3 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 9> {
  static constexpr int kId{9};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<256, 128, 64>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 10: TShape 128x32x64, WShape 32x32x64,
// IShape 16x8x16, 4 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 10> {
  static constexpr int kId{10};
  static constexpr int kNumStages{4};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 32, 64>;
  using WShape = cutlass::gemm::GemmShape<32, 32, 64>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 11: TShape 128x128x32, WShape 64x64x32,
// IShape 16x8x16, 4 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 11> {
  static constexpr int kId{11};
  static constexpr int kNumStages{4};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 128, 32>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 12: TShape 128x128x64, WShape 64x64x64,
// IShape 16x8x16, 4 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 12> {
  static constexpr int kId{12};
  static constexpr int kNumStages{4};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 128, 64>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 13: TShape 256x64x64, WShape 64x64x64,
// IShape 16x8x16, 4 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 13> {
  static constexpr int kId{13};
  static constexpr int kNumStages{4};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<256, 64, 64>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 14: TShape 256x64x32, WShape 64x64x32,
// IShape 16x8x16, 4 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 14> {
  static constexpr int kId{14};
  static constexpr int kNumStages{4};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<256, 64, 32>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 15: TShape 32x64x64, WShape 16x32x64,
// IShape 16x8x16, 5 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 15> {
  static constexpr int kId{15};
  static constexpr int kNumStages{5};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<32, 64, 64>;
  using WShape = cutlass::gemm::GemmShape<16, 32, 64>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 16: TShape 64x64x64, WShape 32x32x64,
// IShape 16x8x16, 5 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 16> {
  static constexpr int kId{16};
  static constexpr int kNumStages{5};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using WShape = cutlass::gemm::GemmShape<32, 32, 64>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 17: TShape 128x128x32, WShape 64x64x32,
// IShape 16x8x16, 5 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 17> {
  static constexpr int kId{17};
  static constexpr int kNumStages{5};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 128, 32>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 18: TShape 128x128x64, WShape 64x64x64,
// IShape 16x8x16, 5 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 18> {
  static constexpr int kId{18};
  static constexpr int kNumStages{5};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 128, 64>;
  using WShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 19: TShape 64x128x32, WShape 32x64x32,
// IShape 16x8x16, 6 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 19> {
  static constexpr int kId{19};
  static constexpr int kNumStages{6};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<64, 128, 32>;
  using WShape = cutlass::gemm::GemmShape<32, 64, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 20: TShape 128x64x32, WShape 64x32x32,
// IShape 16x8x16, 6 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 20> {
  static constexpr int kId{20};
  static constexpr int kNumStages{6};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 64, 32>;
  using WShape = cutlass::gemm::GemmShape<64, 32, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 21: TShape 128x32x32, WShape 32x32x32,
// IShape 16x8x16, 7 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 21> {
  static constexpr int kId{21};
  static constexpr int kNumStages{7};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 32, 32>;
  using WShape = cutlass::gemm::GemmShape<32, 32, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// Generic-ElementT tuning config 22: TShape 64x64x32, WShape 32x32x32,
// IShape 16x8x16, 10 stages.
template <typename ElementT, int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<ElementT, SwizzleFactor, Batched, 22> {
  static constexpr int kId{22};
  static constexpr int kNumStages{10};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using WShape = cutlass::gemm::GemmShape<32, 32, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 16>;
};

// float fallback: chosen for ElementT == float when no more specialized
// float/Id partial specialization matches (e.g. Id = 0).
// TShape 64x64x16, WShape 32x32x16, IShape 16x8x8, 3 stages; kId mirrors Id.
//
// NOTE(review): for Id in 13..22 both this specialization and the
// generic-ElementT specialization for that Id match, and neither is more
// specialized than the other, so GemmTuningConfigs<float, ..., 13..22> should
// be rejected as ambiguous — confirm that float is never requested with those
// ids.
template <int SwizzleFactor, bool Batched, int Id>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, Id> {
  static constexpr int kId{Id};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<64, 64, 16>;
  using WShape = cutlass::gemm::GemmShape<32, 32, 16>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 8>;
};

// float tuning config 1: TShape 64x64x32, WShape 32x32x32, IShape 16x8x8,
// 3 stages.
template <int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, 1> {
  static constexpr int kId{1};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using WShape = cutlass::gemm::GemmShape<32, 32, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 8>;
};

// float tuning config 2: TShape 64x128x32, WShape 32x64x32, IShape 16x8x8,
// 3 stages.
template <int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, 2> {
  static constexpr int kId{2};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<64, 128, 32>;
  using WShape = cutlass::gemm::GemmShape<32, 64, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 8>;
};

// float tuning config 3: TShape 64x256x16, WShape 32x64x16, IShape 16x8x8,
// 3 stages.
template <int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, 3> {
  static constexpr int kId{3};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<64, 256, 16>;
  using WShape = cutlass::gemm::GemmShape<32, 64, 16>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 8>;
};

// float tuning config 4: TShape 64x256x32, WShape 32x64x32, IShape 16x8x8,
// 3 stages.
template <int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, 4> {
  static constexpr int kId{4};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<64, 256, 32>;
  using WShape = cutlass::gemm::GemmShape<32, 64, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 8>;
};

// float tuning config 5: TShape 128x64x32, WShape 64x32x32, IShape 16x8x8,
// 3 stages.
template <int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, 5> {
  static constexpr int kId{5};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 64, 32>;
  using WShape = cutlass::gemm::GemmShape<64, 32, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 8>;
};

// float tuning config 6: TShape 128x128x16, WShape 32x64x16, IShape 16x8x8,
// 3 stages.
template <int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, 6> {
  static constexpr int kId{6};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 128, 16>;
  using WShape = cutlass::gemm::GemmShape<32, 64, 16>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 8>;
};

// float tuning config 7: TShape 128x128x32, WShape 32x64x32, IShape 16x8x8,
// 3 stages.
template <int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, 7> {
  static constexpr int kId{7};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 128, 32>;
  using WShape = cutlass::gemm::GemmShape<32, 64, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 8>;
};

// float tuning config 8: TShape 256x64x16, WShape 64x32x16, IShape 16x8x8,
// 3 stages.
template <int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, 8> {
  static constexpr int kId{8};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<256, 64, 16>;
  using WShape = cutlass::gemm::GemmShape<64, 32, 16>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 8>;
};

// float tuning config 9: TShape 256x64x32, WShape 64x32x32, IShape 16x8x8,
// 3 stages.
template <int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, 9> {
  static constexpr int kId{9};
  static constexpr int kNumStages{3};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<256, 64, 32>;
  using WShape = cutlass::gemm::GemmShape<64, 32, 32>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 8>;
};

// float tuning config 10: TShape 64x128x16, WShape 32x64x16, IShape 16x8x8,
// 4 stages.
template <int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, 10> {
  static constexpr int kId{10};
  static constexpr int kNumStages{4};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<64, 128, 16>;
  using WShape = cutlass::gemm::GemmShape<32, 64, 16>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 8>;
};

// float tuning config 11: TShape 128x64x16, WShape 64x32x16, IShape 16x8x8,
// 4 stages.
template <int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, 11> {
  static constexpr int kId{11};
  static constexpr int kNumStages{4};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 64, 16>;
  using WShape = cutlass::gemm::GemmShape<64, 32, 16>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 8>;
};

// float tuning config 12: TShape 128x128x16, WShape 32x64x16, IShape 16x8x8,
// 4 stages.
template <int SwizzleFactor, bool Batched>
struct GemmTuningConfigs<float, SwizzleFactor, Batched, 12> {
  static constexpr int kId{12};
  static constexpr int kNumStages{4};

  using SwizzleThreadBlock =
      typename SwizzleWrapper<SwizzleFactor, Batched>::Type;

  using TShape = cutlass::gemm::GemmShape<128, 128, 16>;
  using WShape = cutlass::gemm::GemmShape<32, 64, 16>;
  using IShape = cutlass::gemm::GemmShape<16, 8, 8>;
};

}  // namespace ap
