""" Non-Local Attention Pooling Layers

A collection of global pooling layers that go beyond simple avg/max pooling.

LSEPool - LogSumExp pooling, a smooth approximation between avg and max pooling
SimPool - Attention-based pooling from 'Keep It SimPool' (ICCV 2023)

Based on implementations from:
* LSE Pooling: custom implementation by Bill Psomas
* SimPool: https://arxiv.org/abs/2309.06891 - 'Keep It SimPool: Who Said Supervised Transformers
    Suffer from Attention Deficit?' by Bill Psomas et al.

Hacked together by / Copyright 2024 Ross Wightman, original code by Bill Psomas
"""
from typing import Optional, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import use_fused_attn


class LsePlus2d(nn.Module):
    """LogSumExp (LSE) Pooling for 2D (NCHW) inputs.

    A smooth approximation to max pooling that provides a learnable interpolation between
    average and max pooling. When r is large, LSE approaches max pooling; when r is small
    (close to zero, positive), it approaches average pooling.

    Implements: (1/r) * log((1/n) * sum(exp(r * (x - x_max)))) + x_max

    where n = H * W and x_max is the per-channel spatial maximum. The x_max subtraction
    provides numerical stability (every exponent is <= 0 for positive r).
    """

    def __init__(
            self,
            r: float = 10.0,
            r_learnable: bool = True,
            flatten: bool = True,
            device=None,
            dtype=None,
    ):
        """
        Args:
            r: Initial value of the pooling parameter. Higher = closer to max pooling.
                Python ints are accepted and converted to float.
            r_learnable: If True, r is a learnable parameter, otherwise a (persistent) buffer.
            flatten: If True, flatten spatial dims in output.
            device: Device on which to create r.
            dtype: Dtype of r (defaults to the torch default floating point dtype).
        """
        super().__init__()
        # Coerce to a Python float first: torch.tensor(10) infers int64, and an integer
        # tensor cannot be wrapped in nn.Parameter since it cannot require grad.
        r_init = torch.tensor(float(r), device=device, dtype=dtype)
        if r_learnable:
            self.r = nn.Parameter(r_init)
        else:
            self.register_buffer('r', r_init)
        self.flatten = flatten

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool the spatial dims of x.

        Args:
            x: Input of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C) if flatten is set, else (B, C, 1, 1).
        """
        # Per-channel spatial max, kept as (B, C, 1, 1) so it broadcasts against x.
        x_max = F.adaptive_max_pool2d(x, 1)
        exp_x = torch.exp(self.r * (x - x_max))
        # mean (not sum) over H, W supplies the 1/n factor inside the log.
        sum_exp = exp_x.mean(dim=(2, 3), keepdim=True)
        out = x_max + (1.0 / self.r) * torch.log(sum_exp)
        if self.flatten:
            out = out.flatten(1)
        return out


class LsePlus1d(nn.Module):
    """LogSumExp (LSE) Pooling for sequence (NLC) inputs.

    A smooth approximation to max pooling that provides a learnable interpolation between
    average and max pooling. When r is large, LSE approaches max pooling; when r is small
    (close to zero, positive), it approaches average pooling.

    Implements: (1/r) * log((1/n) * sum(exp(r * (x - x_max)))) + x_max

    where n is the sequence length and x_max is the per-channel maximum over the sequence.
    The x_max subtraction provides numerical stability.
    """

    def __init__(
            self,
            r: float = 10.0,
            r_learnable: bool = True,
            device=None,
            dtype=None,
    ):
        """
        Args:
            r: Initial value of the pooling parameter. Higher = closer to max pooling.
                Python ints are accepted and converted to float.
            r_learnable: If True, r is a learnable parameter, otherwise a (persistent) buffer.
            device: Device on which to create r.
            dtype: Dtype of r (defaults to the torch default floating point dtype).
        """
        super().__init__()
        # Coerce to a Python float first: torch.tensor(10) infers int64, and an integer
        # tensor cannot be wrapped in nn.Parameter since it cannot require grad.
        r_init = torch.tensor(float(r), device=device, dtype=dtype)
        if r_learnable:
            self.r = nn.Parameter(r_init)
        else:
            self.register_buffer('r', r_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool over the sequence dim of x.

        Args:
            x: Input of shape (B, N, C).

        Returns:
            Tensor of shape (B, C).
        """
        # Per-channel max over the sequence, kept as (B, 1, C) so it broadcasts against x.
        x_max = x.max(dim=1, keepdim=True).values
        exp_x = torch.exp(self.r * (x - x_max))
        # mean (not sum) over N supplies the 1/n factor inside the log.
        sum_exp = exp_x.mean(dim=1, keepdim=True)
        out = x_max + (1.0 / self.r) * torch.log(sum_exp)
        return out.squeeze(1)  # (B, C)


class SimPool2d(nn.Module):
    """SimPool: Simple Attention-Based Pooling for 2D (NCHW) inputs.

    From 'Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?'
    https://arxiv.org/abs/2309.06891

    Uses GAP as query initialization and applies cross-attention between the GAP query
    and spatial features to produce a weighted pooled representation. Values are the
    normalized features themselves (there is no value or output projection).
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 1,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            gamma: Optional[float] = None,
            norm_layer: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
            flatten: bool = True,
    ):
        """
        Args:
            dim: Input feature dimension (number of channels).
            num_heads: Number of attention heads.
            qkv_bias: If True, add bias to query and key projections.
            qk_norm: If True, apply normalization to queries and keys.
            gamma: If provided, apply power normalization to values with this exponent.
            norm_layer: Normalization layer for patches and optionally qk_norm.
                Defaults to nn.LayerNorm.
            device: Device on which to create the parameters.
            dtype: Dtype of the parameters.
            flatten: If True (default), flatten output to (B, C). If False, keep singleton
                spatial dims and return (B, C, 1, 1).
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        assert dim % num_heads == 0, 'dim must be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.gamma = gamma
        self.flatten = flatten
        self.fused_attn = use_fused_attn()

        norm_layer = norm_layer or nn.LayerNorm
        self.norm = norm_layer(dim, **dd)
        self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd)
        self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd)
        if qk_norm:
            self.q_norm = norm_layer(self.head_dim, **dd)
            self.k_norm = norm_layer(self.head_dim, **dd)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool the spatial dims of x with single-query attention.

        Args:
            x: Input of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C) if flatten is set, else (B, C, 1, 1).
        """
        B, C, H, W = x.shape
        N = H * W

        # Reshape to (B, N, C) for attention
        x = x.flatten(2).transpose(1, 2)  # (B, N, C)

        # GAP as query initialization (taken from the un-normalized features)
        q = x.mean(dim=1, keepdim=True)  # (B, 1, C)

        # Normalize patches for keys and values
        x_norm = self.norm(x)

        # Project query and keys, split heads -> (B, num_heads, 1 or N, head_dim)
        q = self.q(q).reshape(B, 1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k(x_norm).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = x_norm.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        q, k = self.q_norm(q), self.k_norm(k)

        gamma = self.gamma
        if gamma is not None:
            # Power normalization on values: shift each channel to be strictly positive
            # (min over the N positions, plus a small epsilon) before raising to gamma.
            v = (v - v.amin(dim=-2, keepdim=True) + 1e-6).pow(gamma)

        if self.fused_attn:
            # Default SDPA scale is head_dim ** -0.5, i.e. the same as self.scale.
            out = F.scaled_dot_product_attention(q, k, v)
        else:
            attn = (q * self.scale) @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            out = attn @ v

        if gamma is not None:
            # Undo the exponent applied to the values (the shift is not added back).
            out = out.pow(1.0 / gamma)

        # (B, num_heads, 1, head_dim) -> (B, C)
        out = out.transpose(1, 2).reshape(B, C)
        if not self.flatten:
            # Keep singleton spatial dims -> (B, C, 1, 1)
            out = out.reshape(B, C, 1, 1)
        return out


class SimPool1d(nn.Module):
    """SimPool attention pooling over token sequences (NLC layout).

    From 'Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?'
    https://arxiv.org/abs/2309.06891

    A single query, initialized from the mean (GAP) of the tokens, cross-attends to the
    normalized tokens. The normalized tokens serve directly as values, so the result is an
    attention-weighted pooled representation of the sequence.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 1,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            gamma: Optional[float] = None,
            norm_layer: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Input feature dimension.
            num_heads: Number of attention heads, must divide dim.
            qkv_bias: If True, the query and key projections have a bias term.
            qk_norm: If True, queries and keys are normalized per head.
            gamma: Optional exponent for power normalization of the values.
            norm_layer: Norm layer used for the tokens and, when enabled, for qk_norm.
                Falls back to nn.LayerNorm.
            device: Device on which to create the parameters.
            dtype: Dtype of the parameters.
        """
        super().__init__()
        assert dim % num_heads == 0, 'dim must be divisible by num_heads'
        factory_kwargs = dict(device=device, dtype=dtype)

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.gamma = gamma
        self.fused_attn = use_fused_attn()

        if not norm_layer:
            norm_layer = nn.LayerNorm
        self.norm = norm_layer(dim, **factory_kwargs)
        self.q = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
        self.k = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
        self.q_norm = norm_layer(self.head_dim, **factory_kwargs) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, **factory_kwargs) if qk_norm else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool a (B, N, C) token sequence down to (B, C)."""
        head_shape = (self.num_heads, self.head_dim)

        # The query starts from the plain token mean; keys / values use normalized tokens.
        query = x.mean(dim=1, keepdim=True)  # (B, 1, C)
        tokens = self.norm(x)

        # Split channels into heads -> (B, num_heads, 1 or N, head_dim)
        query = self.q(query).unflatten(-1, head_shape).transpose(1, 2)
        key = self.k(tokens).unflatten(-1, head_shape).transpose(1, 2)
        value = tokens.unflatten(-1, head_shape).transpose(1, 2)

        query = self.q_norm(query)
        key = self.k_norm(key)

        gamma = self.gamma
        if gamma is not None:
            # Power normalization: shift values to be strictly positive, then raise to gamma.
            value = (value - value.amin(dim=-2, keepdim=True) + 1e-6).pow(gamma)

        if self.fused_attn:
            pooled = F.scaled_dot_product_attention(query, key, value)
        else:
            weights = ((query * self.scale) @ key.transpose(-2, -1)).softmax(dim=-1)
            pooled = weights @ value

        if gamma is not None:
            pooled = pooled.pow(1.0 / gamma)

        # (B, num_heads, 1, head_dim) -> (B, 1, num_heads, head_dim) -> (B, C)
        return pooled.transpose(1, 2).flatten(1)
