"""CSATv2

A frequency-domain vision model using DCT transforms with spatial attention.

Paper: TBD

This model created by members of MLPA Lab. Welcome feedback and suggestion, questions.
gusdlf93@naver.com
juno.demie.oh@gmail.com

Refined for timm by Ross Wightman
"""
import math
import warnings
from functools import partial, reduce
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.layers import trunc_normal_, DropPath, Mlp, LayerNorm2d, Attention, NormMlpClassifierHead, LayerScale, LayerScale2d
from timm.layers.grn import GlobalResponseNorm
from timm.models._builder import build_model_with_cfg
from timm.models._features import feature_take_indices
from timm.models._manipulate import checkpoint, checkpoint_seq
from ._registry import register_model, generate_default_cfgs

__all__ = ['CSATv2', 'csatv2']

# DCT frequency normalization statistics (Y, Cb, Cr channels x 64 coefficients)
_DCT_MEAN = (
    (932.42657, -0.00260, 0.33415, -0.02840, 0.00003, -0.02792, -0.00183, 0.00006,
     0.00032, 0.03402, -0.00571, 0.00020, 0.00006, -0.00038, -0.00558, -0.00116,
     -0.00000, -0.00047, -0.00008, -0.00030, 0.00942, 0.00161, -0.00009, -0.00006,
     -0.00014, -0.00035, 0.00001, -0.00220, 0.00033, -0.00002, -0.00003, -0.00020,
     0.00007, -0.00000, 0.00005, 0.00293, -0.00004, 0.00006, 0.00019, 0.00004,
     0.00006, -0.00015, -0.00002, 0.00007, 0.00010, -0.00004, 0.00008, 0.00000,
     0.00008, -0.00001, 0.00015, 0.00002, 0.00007, 0.00003, 0.00004, -0.00001,
     0.00004, -0.00000, 0.00002, -0.00000, -0.00008, -0.00000, -0.00003, 0.00003),
    (962.34735, -0.00428, 0.09835, 0.00152, -0.00009, 0.00312, -0.00141, -0.00001,
     -0.00013, 0.01050, 0.00065, 0.00006, -0.00000, 0.00003, 0.00264, 0.00000,
     0.00001, 0.00007, -0.00006, 0.00003, 0.00341, 0.00163, 0.00004, 0.00003,
     -0.00001, 0.00008, -0.00000, 0.00090, 0.00018, -0.00006, -0.00001, 0.00007,
     -0.00003, -0.00001, 0.00006, 0.00084, -0.00000, -0.00001, 0.00000, 0.00004,
     -0.00001, -0.00002, 0.00000, 0.00001, 0.00002, 0.00001, 0.00004, 0.00011,
     0.00000, -0.00003, 0.00011, -0.00002, 0.00001, 0.00001, 0.00001, 0.00001,
     -0.00007, -0.00003, 0.00001, 0.00000, 0.00001, 0.00002, 0.00001, 0.00000),
    (1053.16101, -0.00213, -0.09207, 0.00186, 0.00013, 0.00034, -0.00119, 0.00002,
     0.00011, -0.00984, 0.00046, -0.00007, -0.00001, -0.00005, 0.00180, 0.00042,
     0.00002, -0.00010, 0.00004, 0.00003, -0.00301, 0.00125, -0.00002, -0.00003,
     -0.00001, -0.00001, -0.00001, 0.00056, 0.00021, 0.00001, -0.00001, 0.00002,
     -0.00001, -0.00001, 0.00005, -0.00070, -0.00002, -0.00002, 0.00005, -0.00004,
     -0.00000, 0.00002, -0.00002, 0.00001, 0.00000, -0.00003, 0.00004, 0.00007,
     0.00001, 0.00000, 0.00013, -0.00000, 0.00000, 0.00002, -0.00000, -0.00001,
     -0.00004, -0.00003, 0.00000, 0.00001, -0.00001, 0.00001, -0.00000, 0.00000),
)

_DCT_VAR = (
    (270372.37500, 6287.10645, 5974.94043, 1653.10889, 1463.91748, 1832.58997, 755.92468, 692.41528,
     648.57184, 641.46881, 285.79288, 301.62100, 380.43405, 349.84027, 374.15891, 190.30960,
     190.76746, 221.64578, 200.82646, 145.87979, 126.92046, 62.14622, 67.75562, 102.42001,
     129.74922, 130.04631, 103.12189, 97.76417, 53.17402, 54.81048, 73.48712, 81.04342,
     69.35100, 49.06024, 33.96053, 37.03279, 20.48858, 24.94830, 33.90822, 44.54912,
     47.56363, 40.03160, 30.43313, 22.63899, 26.53739, 26.57114, 21.84404, 17.41557,
     15.18253, 10.69678, 11.24111, 12.97229, 15.08971, 15.31646, 8.90409, 7.44213,
     6.66096, 6.97719, 4.17834, 3.83882, 4.51073, 2.36646, 2.41363, 1.48266),
    (18839.21094, 321.70932, 300.15259, 77.47830, 76.02293, 89.04748, 33.99642, 34.74807,
     32.12333, 28.19588, 12.04675, 14.26871, 18.45779, 16.59588, 15.67892, 7.37718,
     8.56312, 10.28946, 9.41013, 6.69090, 5.16453, 2.55186, 3.03073, 4.66765,
     5.85418, 5.74644, 4.33702, 3.66948, 1.95107, 2.26034, 3.06380, 3.50705,
     3.06359, 2.19284, 1.54454, 1.57860, 0.97078, 1.13941, 1.48653, 1.89996,
     1.95544, 1.64950, 1.24754, 0.93677, 1.09267, 1.09516, 0.94163, 0.78966,
     0.72489, 0.50841, 0.50909, 0.55664, 0.63111, 0.64125, 0.38847, 0.33378,
     0.30918, 0.33463, 0.20875, 0.19298, 0.21903, 0.13380, 0.13444, 0.09554),
    (17127.39844, 292.81421, 271.45209, 66.64056, 63.60253, 76.35437, 28.06587, 27.84831,
     25.96656, 23.60370, 9.99173, 11.34992, 14.46955, 12.92553, 12.69353, 5.91537,
     6.60187, 7.90891, 7.32825, 5.32785, 4.29660, 2.13459, 2.44135, 3.66021,
     4.50335, 4.38959, 3.34888, 2.97181, 1.60633, 1.77010, 2.35118, 2.69018,
     2.38189, 1.74596, 1.26014, 1.31684, 0.79327, 0.92046, 1.17670, 1.47609,
     1.50914, 1.28725, 0.99898, 0.74832, 0.85736, 0.85800, 0.74663, 0.63508,
     0.58748, 0.41098, 0.41121, 0.44663, 0.50277, 0.51519, 0.31729, 0.27336,
     0.25399, 0.27241, 0.17353, 0.16255, 0.18440, 0.11602, 0.11511, 0.08450),
)


def _zigzag_permutation(rows: int, cols: int) -> List[int]:
    """Generate zigzag scan order for DCT coefficients."""
    idx_matrix = np.arange(0, rows * cols, 1).reshape(rows, cols).tolist()
    dia = [[] for _ in range(rows + cols - 1)]
    zigzag = []
    for i in range(rows):
        for j in range(cols):
            s = i + j
            if s % 2 == 0:
                dia[s].insert(0, idx_matrix[i][j])
            else:
                dia[s].append(idx_matrix[i][j])
    for d in dia:
        zigzag.extend(d)
    return zigzag


def _dct_kernel_type_2(
        kernel_size: int,
        orthonormal: bool,
        device=None,
        dtype=None,
) -> torch.Tensor:
    """Generate Type-II DCT kernel matrix."""
    dd = dict(device=device, dtype=dtype)
    x = torch.eye(kernel_size, **dd)
    v = x.clone().contiguous().view(-1, kernel_size)
    v = torch.cat([v, v.flip([1])], dim=-1)
    v = torch.fft.fft(v, dim=-1)[:, :kernel_size]
    k = (
        torch.tensor(-1j, device=device, dtype=torch.complex64) * torch.pi
        * torch.arange(kernel_size, device=device, dtype=torch.long)[None, :]
    )
    k = torch.exp(k / (kernel_size * 2))
    v = v * k
    v = v.real
    if orthonormal:
        v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **dd))
        v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **dd))
    v = v.contiguous().view(*x.shape)
    return v


def _dct_kernel_type_3(
        kernel_size: int,
        orthonormal: bool,
        device=None,
        dtype=None,
) -> torch.Tensor:
    """Generate Type-III DCT kernel matrix (inverse of Type-II)."""
    return torch.linalg.inv(_dct_kernel_type_2(kernel_size, orthonormal, device, dtype))


class Dct1d(nn.Module):
    """1D Discrete Cosine Transform layer."""

    def __init__(
            self,
            kernel_size: int,
            kernel_type: int = 2,
            orthonormal: bool = True,
            device=None,
            dtype=None,
    ) -> None:
        dd = dict(device=device, dtype=dtype)
        super().__init__()
        kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3}
        dct_weights = kernel[f'{kernel_type}'](kernel_size, orthonormal, **dd).T
        self.register_buffer('weights', dct_weights.contiguous())
        self.register_parameter('bias', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weights, self.bias)


class Dct2d(nn.Module):
    """2D Discrete Cosine Transform layer."""

    def __init__(
            self,
            kernel_size: int,
            kernel_type: int = 2,
            orthonormal: bool = True,
            device=None,
            dtype=None,
    ) -> None:
        dd = dict(device=device, dtype=dtype)
        super().__init__()
        self.transform = Dct1d(kernel_size, kernel_type, orthonormal, **dd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2)


def _split_out_chs(out_chs: int, ratio=(24, 4, 4)):
    # reduce ratio to smallest integers (24,4,4) -> (6,1,1)
    g = reduce(math.gcd, ratio)
    r = tuple(x // g for x in ratio)
    denom = sum(r)

    assert out_chs % denom == 0 and out_chs >= denom, (
        f"out_chs={out_chs} can't be split into Y/Cb/Cr with ratio {ratio} "
        f"(reduced {r}); out_chs must be a multiple of {denom}."
    )

    unit = out_chs // denom
    y, cb, cr = (ri * unit for ri in r)
    assert y + cb + cr == out_chs and min(y, cb, cr) > 0
    return y, cb, cr


class LearnableDct2d(nn.Module):
    """Learnable 2D DCT stem with RGB to YCbCr conversion and frequency selection."""

    def __init__(
            self,
            kernel_size: int,
            kernel_type: int = 2,
            orthonormal: bool = True,
            out_chs: int = 32,
            device=None,
            dtype=None,
    ) -> None:
        dd = dict(device=device, dtype=dtype)
        super().__init__()
        self.k = kernel_size
        self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd)
        self.permutation = _zigzag_permutation(kernel_size, kernel_size)

        y_ch, cb_ch, cr_ch = _split_out_chs(out_chs, ratio=(24, 4, 4))
        self.conv_y  = nn.Conv2d(kernel_size ** 2, y_ch,  kernel_size=1, padding=0, **dd)
        self.conv_cb = nn.Conv2d(kernel_size ** 2, cb_ch, kernel_size=1, padding=0, **dd)
        self.conv_cr = nn.Conv2d(kernel_size ** 2, cr_ch, kernel_size=1, padding=0, **dd)

        # Register empty buffers for DCT normalization statistics
        self.register_buffer('mean', torch.empty(3, 64, device=device, dtype=dtype), persistent=False)
        self.register_buffer('var', torch.empty(3, 64, device=device, dtype=dtype), persistent=False)
        # Shape (3, 1, 1) for BCHW broadcasting
        self.register_buffer('imagenet_mean', torch.empty(3, 1, 1, device=device, dtype=dtype), persistent=False)
        self.register_buffer('imagenet_std', torch.empty(3, 1, 1, device=device, dtype=dtype), persistent=False)

        # TODO: skip init when on meta device when safe to do so
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize buffers."""
        self._init_buffers()

    def _init_buffers(self) -> None:
        """Compute and fill non-persistent buffer values."""
        self.mean.copy_(torch.tensor(_DCT_MEAN))
        self.var.copy_(torch.tensor(_DCT_VAR))
        self.imagenet_mean.copy_(torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1))
        self.imagenet_std.copy_(torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1))

    def init_non_persistent_buffers(self) -> None:
        """Initialize non-persistent buffers."""
        self._init_buffers()

    def _denormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Convert from ImageNet normalized to [0, 255] range."""
        return x.mul(self.imagenet_std).add_(self.imagenet_mean) * 255

    def _rgb_to_ycbcr(self, x: torch.Tensor) -> torch.Tensor:
        """Convert RGB to YCbCr color space (BCHW input/output)."""
        r, g, b = x[:, 0], x[:, 1], x[:, 2]
        y = r * 0.299 + g * 0.587 + b * 0.114
        cb = 0.564 * (b - y) + 128
        cr = 0.713 * (r - y) + 128
        return torch.stack([y, cb, cr], dim=1)

    def _frequency_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize DCT coefficients using precomputed statistics."""
        std = self.var ** 0.5 + 1e-8
        return (x - self.mean) / std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        x = self._denormalize(x)
        x = self._rgb_to_ycbcr(x)
        # Extract non-overlapping k x k patches
        x = x.reshape(b, c, h // self.k, self.k, w // self.k, self.k)  # (B, C, H//k, k, W//k, k)
        x = x.permute(0, 2, 4, 1, 3, 5)  # (B, H//k, W//k, C, k, k)
        x = self.transform(x)
        x = x.reshape(-1, c, self.k * self.k)
        x = x[:, :, self.permutation]
        x = self._frequency_normalize(x)
        x = x.reshape(b, h // self.k, w // self.k, c, -1)
        x = x.permute(0, 3, 4, 1, 2).contiguous()
        x_y = self.conv_y(x[:, 0])
        x_cb = self.conv_cb(x[:, 1])
        x_cr = self.conv_cr(x[:, 2])
        return torch.cat([x_y, x_cb, x_cr], dim=1)


class Dct2dStats(nn.Module):
    """Utility module to compute DCT coefficient statistics."""

    def __init__(
            self,
            kernel_size: int,
            kernel_type: int = 2,
            orthonormal: bool = True,
            device=None,
            dtype=None,
    ) -> None:
        dd = dict(device=device, dtype=dtype)
        super().__init__()
        self.k = kernel_size
        self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd)
        self.permutation = _zigzag_permutation(kernel_size, kernel_size)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        b, c, h, w = x.shape
        # Extract non-overlapping k x k patches
        x = x.reshape(b, c, h // self.k, self.k, w // self.k, self.k)  # (B, C, H//k, k, W//k, k)
        x = x.permute(0, 2, 4, 1, 3, 5)  # (B, H//k, W//k, C, k, k)
        x = self.transform(x)
        x = x.reshape(-1, c, self.k * self.k)
        x = x[:, :, self.permutation]
        x = x.reshape(b * (h // self.k) * (w // self.k), c, -1)

        mean_list = torch.zeros([3, 64])
        var_list = torch.zeros([3, 64])
        for i in range(3):
            mean_list[i] = torch.mean(x[:, i], dim=0)
            var_list[i] = torch.var(x[:, i], dim=0)
        return mean_list, var_list


class Block(nn.Module):
    """ConvNeXt-style block with spatial attention."""

    def __init__(
            self,
            dim: int,
            drop_path: float = 0.,
            ls_init_value: Optional[float] = None,
            device=None,
            dtype=None,
    ) -> None:
        dd = dict(device=device, dtype=dtype)
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, **dd)
        self.norm = nn.LayerNorm(dim, eps=1e-6, **dd)
        self.pwconv1 = nn.Linear(dim, 4 * dim, **dd)
        self.act = nn.GELU()
        self.grn = GlobalResponseNorm(4 * dim, channels_last=True, **dd)
        self.pwconv2 = nn.Linear(4 * dim, dim, **dd)
        self.ls = LayerScale2d(dim, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.attn = SpatialAttention(**dd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)

        attn = self.attn(x)
        attn = F.interpolate(attn, size=x.shape[2:], mode='bilinear', align_corners=True)
        x = x * attn
        x = self.ls(x)

        return shortcut + self.drop_path(x)


class SpatialTransformerBlock(nn.Module):
    """Lightweight transformer block for spatial attention (1-channel, 7x7 grid).

    This is a simplified transformer with single-head, 1-dim attention over spatial
    positions. Used inside SpatialAttention where input is 1 channel at 7x7 resolution.
    """

    def __init__(
            self,
            device=None,
            dtype=None,
    ) -> None:
        dd = dict(device=device, dtype=dtype)
        super().__init__()
        # Single-head attention with 1-dim q/k/v (no output projection needed)
        self.pos_embed = PosConv(in_chans=1, **dd)
        self.norm1 = nn.LayerNorm(1, **dd)
        self.qkv = nn.Linear(1, 3, bias=False, **dd)

        # Feedforward: 1 -> 4 -> 1
        self.norm2 = nn.LayerNorm(1, **dd)
        self.mlp = Mlp(1, 4, 1, act_layer=nn.GELU, **dd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape

        # Attention block
        shortcut = x
        x_t = x.flatten(2).transpose(1, 2)  # (B, N, 1)
        x_t = self.norm1(x_t)
        x_t = self.pos_embed(x_t, (H, W))

        # Simple single-head attention with scalar q/k/v
        qkv = self.qkv(x_t)  # (B, N, 3)
        q, k, v = qkv.unbind(-1)  # each (B, N)
        attn = (q @ k.transpose(-1, -2)).softmax(dim=-1)  # (B, N, N)
        x_t = (attn @ v).unsqueeze(-1)  # (B, N, 1)

        x_t = x_t.transpose(1, 2).reshape(B, C, H, W)
        x = shortcut + x_t

        # Feedforward block
        shortcut = x
        x_t = x.flatten(2).transpose(1, 2)
        x_t = self.mlp(self.norm2(x_t))
        x_t = x_t.transpose(1, 2).reshape(B, C, H, W)
        x = shortcut + x_t

        return x


class SpatialAttention(nn.Module):
    """Spatial attention module using channel statistics and transformer."""

    def __init__(
            self,
            device=None,
            dtype=None,
    ) -> None:
        dd = dict(device=device, dtype=dtype)
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, **dd)
        self.attn = SpatialTransformerBlock(**dd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_avg = x.mean(dim=1, keepdim=True)
        x_max = x.amax(dim=1, keepdim=True)
        x = torch.cat([x_avg, x_max], dim=1)
        x = self.avgpool(x)
        x = self.conv(x)
        x = self.attn(x)
        return x


class TransformerBlock(nn.Module):
    """Transformer block with optional downsampling and convolutional position encoding."""

    def __init__(
            self,
            inp: int,
            oup: int,
            num_heads: int = 8,
            attn_head_dim: int = 32,
            downsample: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            drop_path: float = 0.,
            ls_init_value: Optional[float] = None,
            device=None,
            dtype=None,
    ) -> None:
        dd = dict(device=device, dtype=dtype)
        super().__init__()
        hidden_dim = int(inp * 4)
        self.downsample = downsample

        if self.downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False, **dd)
        else:
            self.pool1 = nn.Identity()
            self.pool2 = nn.Identity()
            self.proj = nn.Identity()

        self.pos_embed = PosConv(in_chans=inp, **dd)
        self.norm1 = nn.LayerNorm(inp, **dd)
        self.attn = Attention(
            dim=inp,
            num_heads=num_heads,
            attn_head_dim=attn_head_dim,
            dim_out=oup,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            **dd,
        )
        self.ls1 = LayerScale(oup, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = nn.LayerNorm(oup, **dd)
        self.mlp = Mlp(oup, hidden_dim, oup, act_layer=nn.GELU, drop=proj_drop, **dd)
        self.ls2 = LayerScale(oup, init_values=ls_init_value, **dd) if ls_init_value else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.downsample:
            shortcut = self.proj(self.pool1(x))
            x_t = self.pool2(x)
            B, C, H, W = x_t.shape
            x_t = x_t.flatten(2).transpose(1, 2)
            x_t = self.norm1(x_t)
            x_t = self.pos_embed(x_t, (H, W))
            x_t = self.ls1(self.attn(x_t))
            x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
            x = shortcut + self.drop_path1(x_t)
        else:
            B, C, H, W = x.shape
            shortcut = x
            x_t = x.flatten(2).transpose(1, 2)
            x_t = self.norm1(x_t)
            x_t = self.pos_embed(x_t, (H, W))
            x_t = self.ls1(self.attn(x_t))
            x_t = x_t.transpose(1, 2).reshape(B, -1, H, W)
            x = shortcut + self.drop_path1(x_t)

        # MLP block
        B, C, H, W = x.shape
        shortcut = x
        x_t = x.flatten(2).transpose(1, 2)
        x_t = self.ls2(self.mlp(self.norm2(x_t)))
        x_t = x_t.transpose(1, 2).reshape(B, C, H, W)
        x = shortcut + self.drop_path2(x_t)

        return x


class PosConv(nn.Module):
    """Convolutional position encoding."""

    def __init__(
            self,
            in_chans: int,
            device=None,
            dtype=None,
    ) -> None:
        dd = dict(device=device, dtype=dtype)
        super().__init__()
        self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans, **dd)

    def forward(self, x: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
        B, N, C = x.shape
        H, W = size
        cnn_feat = x.transpose(1, 2).view(B, C, H, W)
        x = self.proj(cnn_feat) + cnn_feat
        return x.flatten(2).transpose(1, 2)


class CSATv2(nn.Module):
    """CSATv2: Frequency-domain vision model with spatial attention.

    A hybrid architecture that processes images in the DCT frequency domain
    with ConvNeXt-style blocks and transformer attention.
    """

    def __init__(
            self,
            num_classes: int = 1000,
            in_chans: int = 3,
            dims: Tuple[int, ...] = (32, 72, 168, 386),
            depths: Tuple[int, ...] = (2, 2, 8, 6),
            transformer_depths: Tuple[int, ...] = (0, 0, 2, 2),
            drop_path_rate: float = 0.0,
            transformer_drop_path: bool = False,
            ls_init_value: Optional[float] = None,
            global_pool: str = 'avg',
            device=None,
            dtype=None,
            **kwargs,
    ) -> None:
        dd = dict(device=device, dtype=dtype)
        super().__init__()
        if in_chans != 3:
            warnings.warn(
                f'CSATv2 is designed for 3-channel RGB input. '
                f'in_chans={in_chans} may not work correctly with the DCT stem.'
            )
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.global_pool = global_pool
        self.grad_checkpointing = False

        self.num_features = dims[-1]
        self.head_hidden_size = self.num_features

        # Build feature_info dynamically
        self.feature_info = [dict(num_chs=dims[0], reduction=8, module='stem_dct')]
        reduction = 8
        for i, dim in enumerate(dims):
            if i > 0:
                reduction *= 2
            self.feature_info.append(dict(num_chs=dim, reduction=reduction, module=f'stages.{i}'))

        # Build drop path rates for all blocks (0 for transformer blocks when transformer_drop_path=False)
        total_blocks = sum(depths) if transformer_drop_path else sum(d - t for d, t in zip(depths, transformer_depths))
        dp_iter = iter(torch.linspace(0, drop_path_rate, total_blocks).tolist())
        dp_rates = []
        for depth, t_depth in zip(depths, transformer_depths):
            dp_rates += [next(dp_iter) for _ in range(depth - t_depth)]
            dp_rates += [next(dp_iter) if transformer_drop_path else 0. for _ in range(t_depth)]

        self.stem_dct = LearnableDct2d(8, out_chs=dims[0], **dd)

        # Build stages dynamically
        dp_iter = iter(dp_rates)
        stages = []
        for i, (dim, depth, t_depth) in enumerate(zip(dims, depths, transformer_depths)):
            layers = (
                # Downsample at start of stage (except first stage)
                ([nn.Conv2d(dims[i - 1], dim, kernel_size=2, stride=2, **dd)] if i > 0 else []) +
                # Conv blocks
                [Block(dim=dim, drop_path=next(dp_iter), ls_init_value=ls_init_value, **dd) for _ in range(depth - t_depth)] +
                # Transformer blocks at end of stage
                [TransformerBlock(inp=dim, oup=dim, drop_path=next(dp_iter), ls_init_value=ls_init_value, **dd) for _ in range(t_depth)] +
                # Trailing LayerNorm (except last stage)
                ([LayerNorm2d(dim, eps=1e-6, **dd)] if i < len(depths) - 1 else [])
            )
            stages.append(nn.Sequential(*layers))
        self.stages = nn.Sequential(*stages)

        self.head = NormMlpClassifierHead(dims[-1], num_classes, pool_type=global_pool, **dd)

        # TODO: skip init when on meta device when safe to do so
        self.init_weights(needs_reset=False)

    def init_weights(self, needs_reset: bool = True):
        self.apply(partial(self._init_weights, needs_reset=needs_reset))

    def _init_weights(self, m: nn.Module, needs_reset: bool = True) -> None:
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif needs_reset and hasattr(m, 'reset_parameters'):
            m.reset_parameters()

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head.reset(num_classes, pool_type=global_pool)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        self.grad_checkpointing = enable

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem_dct(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        return x

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Forward pass returning intermediate features.

        Args:
            x: Input image tensor.
            indices: Indices of features to return (0=stem_dct, 1-4=stages). None returns all.
            norm: Apply norm layer to final intermediate (unused, for API compat).
            stop_early: Stop iterating when last desired intermediate is reached.
            output_fmt: Output format, must be 'NCHW'.
            intermediates_only: Only return intermediate features.

        Returns:
            List of intermediate features or tuple of (final features, intermediates).
        """
        assert output_fmt == 'NCHW', 'Output format must be NCHW.'
        intermediates = []
        # 5 feature levels: stem_dct (0) + stages 0-3 (1-4)
        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)

        x = self.stem_dct(x)
        if 0 in take_indices:
            intermediates.append(x)

        if torch.jit.is_scripting() or not stop_early:
            stages = self.stages
        else:
            # max_index is 0-4, stages are 1-4, so we need max_index stages
            stages = self.stages[:max_index] if max_index > 0 else []

        for feat_idx, stage in enumerate(stages):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(stage, x)
            else:
                x = stage(x)
            if feat_idx + 1 in take_indices:  # +1 because stem is index 0
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ) -> List[int]:
        """Prune layers not required for specified intermediates.

        Args:
            indices: Indices of intermediate layers to keep (0=stem_dct, 1-4=stages).
            prune_norm: Whether to prune the final norm layer.
            prune_head: Whether to prune the classifier head.

        Returns:
            List of indices that were kept.
        """
        # 5 feature levels: stem_dct (0) + stages 0-3 (1-4)
        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
        # max_index is 0-4, stages are 1-4, so we keep max_index stages
        self.stages = self.stages[:max_index] if max_index > 0 else nn.Sequential()

        if prune_norm:
            self.head.norm = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')

        return take_indices

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        return self.forward_head(x)


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 512, 512), 'pool_size': (8, 8),
        'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
        'interpolation': 'bilinear', 'crop_pct': 1.0,
        'classifier': 'head.fc', 'first_conv': [],
        **kwargs,
    }


default_cfgs = generate_default_cfgs({
    'csatv2.r512_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'csatv2_21m.sw_r640_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 640, 640),
        interpolation='bicubic',
    ),
    'csatv2_21m.sw_r512_in1k': _cfg(
        hf_hub_id='timm/',
        pool_size=(10, 10),
        interpolation='bicubic',
    ),
})


def checkpoint_filter_fn(state_dict: dict, model: nn.Module) -> dict:
    """Remap original CSATv2 checkpoint to timm format.

    Handles two key structural changes:
    1) Stage naming: stages1/2/3/4 -> stages.0/1/2/3
    2) Downsample position: moved from end of stage N to start of stage N+1
    """
    if "stages.0.0.grn.weight" in state_dict:
        return state_dict  # already in timm format

    import re

    # FIXME this downsample idx is wired to the original 'csatv2' model size
    downsample_idx = {1: 3, 2: 3, 3: 9}  # original stage -> downsample index

    dct_re   = re.compile(r"^dct\.")
    stage_re = re.compile(r"^stages([1-4])\.(\d+)\.(.*)$")
    head_re  = re.compile(r"^head\.")
    norm_re  = re.compile(r"^norm\.")

    def remap_stage(m: re.Match) -> str:
        stage, idx, rest = int(m.group(1)), int(m.group(2)), m.group(3)
        if stage in downsample_idx and idx == downsample_idx[stage]:
            return f"stages.{stage}.0.{rest}"                 # move downsample to next stage @0
        if stage == 1:
            return f"stages.0.{idx}.{rest}"                  # stage1 -> stages.0
        return f"stages.{stage - 1}.{idx + 1}.{rest}"        # stage2-4 -> stages.1-3, shift +1

    out = {}
    for k, v in state_dict.items():
        # dct -> stem_dct, and Y/Cb/Cr conv names
        k = dct_re.sub("stem_dct.", k)
        k = (k.replace(".Y_Conv.",  ".conv_y.")
               .replace(".Cb_Conv.", ".conv_cb.")
               .replace(".Cr_Conv.", ".conv_cr."))

        # stage remap + downsample relocation
        k = stage_re.sub(remap_stage, k)

        # GRN: gamma/beta -> weight/bias (reshape)
        if "grn.gamma" in k:
            k, v = k.replace("grn.gamma", "grn.weight"), v.reshape(-1)
        elif "grn.beta" in k:
            k, v = k.replace("grn.beta", "grn.bias"), v.reshape(-1)

        # FeedForward(nn.Sequential) -> Mlp + norm renames
        if ".ff.net.0." in k:
            k = k.replace(".ff.net.0.", ".mlp.fc1.")
        elif ".ff.net.3." in k:
            k = k.replace(".ff.net.3.", ".mlp.fc2.")
        elif ".ff_norm." in k:
            k = k.replace(".ff_norm.", ".norm2.")
        elif ".attn_norm." in k:
            k = k.replace(".attn_norm.", ".norm1.")

        # attention -> attn (handle nested first)
        if ".attention.attention." in k:
            k = (k.replace(".attention.attention.attn.to_qkv.", ".attn.attn.qkv.")
                   .replace(".attention.attention.attn.",        ".attn.attn.")
                   .replace(".attention.attention.",             ".attn.attn."))
        elif ".attention." in k:
            k = k.replace(".attention.", ".attn.")

        # TransformerBlock attention name remaps
        if ".attn.to_qkv." in k:
            k = k.replace(".attn.to_qkv.", ".attn.qkv.")
        elif ".attn.to_out.0." in k:
            k = k.replace(".attn.to_out.0.", ".attn.proj.")

        # .attn.pos_embed -> .pos_embed (but not SpatialTransformerBlock's .attn.attn.pos_embed)
        if ".attn.pos_embed." in k and ".attn.attn." not in k:
            k = k.replace(".attn.pos_embed.", ".pos_embed.")

        # head -> head.fc, norm -> head.norm (order matters)
        k = head_re.sub("head.fc.", k)
        k = norm_re.sub("head.norm.", k)

        out[k] = v

    return out


def _create_csatv2(variant: str, pretrained: bool = False, **kwargs) -> CSATv2:
    out_indices = kwargs.pop('out_indices', (1, 2, 3, 4))
    return build_model_with_cfg(
        CSATv2,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=out_indices, flatten_sequential=True),
        default_cfg=default_cfgs[variant],
        **kwargs,
    )


@register_model
def csatv2(pretrained: bool = False, **kwargs) -> CSATv2:
    return _create_csatv2('csatv2', pretrained, **kwargs)


@register_model
def csatv2_21m(pretrained: bool = False, **kwargs) -> CSATv2:
    # experimental ~20-21M param larger model to validate flexible arch spec
    model_args = dict(
        dims = (48, 96, 224, 448),
        depths = (3, 3, 10, 8),
        transformer_depths = (0, 0, 4, 3)

    )
    return _create_csatv2('csatv2_21m', pretrained, **dict(model_args, **kwargs))