# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import itertools

import paddle
import paddle.distributed as dist
import paddle.nn.functional as F


def _get_comm_group_by_dim(mesh, dim):
    dim_names = mesh.dim_names
    assert dim in dim_names, f"dim '{dim}' not in mesh.dim_names {dim_names}"

    shape = mesh.shape
    dim_idx = dim_names.index(dim)
    ids = mesh.process_ids

    def nest(flat, shape):
        if not shape:
            return flat[0]
        step = int(len(flat) // shape[0])
        return [
            nest(flat[i * step : (i + 1) * step], shape[1:])
            for i in range(shape[0])
        ]

    mesh_nd = nest(ids, shape)

    other_axes = [i for i in range(len(shape)) if i != dim_idx]
    other_ranges = [range(shape[i]) for i in other_axes]

    comm_groups = []
    for index in itertools.product(*other_ranges):
        group = []
        for i in range(shape[dim_idx]):
            idx = list(index)
            idx.insert(dim_idx, i)
            val = mesh_nd
            for j in idx:
                val = val[j]
            group.append(val)
        comm_groups.append(group)

    return comm_groups


def _get_conv_tp_group(x_mesh, x_placements, data_format):
    if data_format == "NCHW":
        shard_axis = 3
    else:
        shard_axis = 2

    axis_name = None
    for i, placement in enumerate(x_placements):
        if placement == dist.Shard(shard_axis):
            axis_name = x_mesh.dim_names[i]
            break

    if not axis_name:
        raise ValueError(
            f"Input tensor placements {x_placements} do not contain a Shard on W axis:{shard_axis}."
        )

    tp_groups = _get_comm_group_by_dim(x_mesh, axis_name)
    rank = dist.get_rank()

    for group in tp_groups:
        if rank in group:
            return axis_name, group

    raise RuntimeError(
        f"Rank {rank} not found in any tensor parallel group for mesh {x_mesh}."
    )


def _ring_conv_halo_exchange(
    local_input_tensor,
    halo_width_to_receive_from_left,
    halo_width_to_receive_from_right,
    left_neighbor_rank,
    right_neighbor_rank,
    current_rank,
    conv_tp_group,
    data_format,
):
    if len(conv_tp_group) == 1:
        return local_input_tensor

    if not (
        len(local_input_tensor.shape) == 4
    ):  # Assuming 4D tensors like NCHW/NHWC
        raise ValueError(
            f"Input tensor is expected to be 4D for NCHW/NHWC formats, "
            f"but got {len(local_input_tensor.shape)}D."
        )

    if data_format == "NCHW":
        width_dim_idx = 3
    elif data_format == "NHWC":
        width_dim_idx = 2
    else:
        raise ValueError(
            f"Unsupported data_format: {data_format}. Must be 'NCHW' or 'NHWC'."
        )

    # Segment to send to the right neighbor (right_neighbor_rank)
    slices_for_send_right = [slice(None)] * 4
    slices_for_send_right[width_dim_idx] = slice(
        -halo_width_to_receive_from_left, None
    )
    segment_to_send_right = local_input_tensor[
        tuple(slices_for_send_right)
    ].contiguous()

    # Segment to send to the left neighbor (left_neighbor_rank)
    slices_for_send_left = [slice(None)] * 4
    slices_for_send_left[width_dim_idx] = slice(
        None, halo_width_to_receive_from_right
    )
    segment_to_send_left = local_input_tensor[
        tuple(slices_for_send_left)
    ].contiguous()

    buffer_for_halo_from_right = paddle.zeros_like(segment_to_send_left)
    buffer_for_halo_from_left = paddle.zeros_like(segment_to_send_right)

    op_isend_to_right = dist.P2POp(
        dist.isend, segment_to_send_right, right_neighbor_rank
    )
    op_isend_to_left = dist.P2POp(
        dist.isend, segment_to_send_left, left_neighbor_rank
    )

    op_irecv_from_right = dist.P2POp(
        dist.irecv, buffer_for_halo_from_right, right_neighbor_rank
    )
    op_irecv_from_left = dist.P2POp(
        dist.irecv, buffer_for_halo_from_left, left_neighbor_rank
    )

    p2p_requests = dist.batch_isend_irecv(
        [
            op_isend_to_right,
            op_isend_to_left,
            op_irecv_from_left,
            op_irecv_from_right,
        ]
    )
    for req in p2p_requests:
        req.wait()

    # Concatenate received halo regions with the local tensor
    if current_rank == conv_tp_group[0]:
        # First rank: original tensor || halo_from_right
        reconstructed_tensor = paddle.concat(
            [local_input_tensor, buffer_for_halo_from_right], axis=width_dim_idx
        )
    elif current_rank == conv_tp_group[-1]:
        # Last rank: halo_from_left || original tensor
        reconstructed_tensor = paddle.concat(
            [buffer_for_halo_from_left, local_input_tensor], axis=width_dim_idx
        )
    else:
        # Middle ranks: halo_from_left || original tensor || halo_from_right
        reconstructed_tensor = paddle.concat(
            [
                buffer_for_halo_from_left,
                local_input_tensor,
                buffer_for_halo_from_right,
            ],
            axis=width_dim_idx,
        )

    return reconstructed_tensor.contiguous()


def _ring_conv_halo_aggregate(
    local_gradient_tensor,
    halo_width_send_left,
    halo_width_send_right,
    left_neighbor_rank,
    right_neighbor_rank,
    current_process_rank,
    conv_tp_group,
    data_format,
):
    if len(conv_tp_group) == 1:
        return local_gradient_tensor

    if data_format == "NCHW":
        width_dim_idx = 3
    elif data_format == "NHWC":
        width_dim_idx = 2
    else:
        raise ValueError(
            f"Unsupported data_format: {data_format}. Must be 'NCHW' or 'NHWC'."
        )

    # Prepare gradient segments to send
    slices_for_send_right = [slice(None)] * 4
    slices_for_send_right[width_dim_idx] = slice(
        -halo_width_send_right, None
    )  # Send the rightmost part
    segment_to_send_right = local_gradient_tensor[
        tuple(slices_for_send_right)
    ].contiguous()

    slices_for_send_left = [slice(None)] * 4
    slices_for_send_left[width_dim_idx] = slice(
        None, halo_width_send_left
    )  # Send the leftmost part
    segment_to_send_left = local_gradient_tensor[
        tuple(slices_for_send_left)
    ].contiguous()

    # Buffers for receiving gradients
    buffer_for_gradient_from_left = paddle.zeros_like(segment_to_send_right)
    buffer_for_gradient_from_right = paddle.zeros_like(segment_to_send_left)

    op_isend_to_right = dist.P2POp(
        dist.isend, segment_to_send_right, right_neighbor_rank
    )
    op_isend_to_left = dist.P2POp(
        dist.isend, segment_to_send_left, left_neighbor_rank
    )
    op_irecv_from_right = dist.P2POp(
        dist.irecv, buffer_for_gradient_from_right, right_neighbor_rank
    )
    op_irecv_from_left = dist.P2POp(
        dist.irecv, buffer_for_gradient_from_left, left_neighbor_rank
    )

    p2p_requests = dist.batch_isend_irecv(
        [
            op_isend_to_right,
            op_isend_to_left,
            op_irecv_from_left,
            op_irecv_from_right,
        ]
    )
    for req in p2p_requests:
        req.wait()

    processed_gradient_tensor = local_gradient_tensor
    # Crop local tensor and aggregate received gradients
    if current_process_rank == conv_tp_group[0]:
        # Crop the part sent to the right neighbor
        crop_slices = [slice(None)] * 4
        crop_slices[width_dim_idx] = slice(None, -halo_width_send_right)
        processed_gradient_tensor = processed_gradient_tensor[
            tuple(crop_slices)
        ]

        # Aggregate gradient received from the right neighbor
        # This is added to the new rightmost part of the processed_gradient_tensor
        agg_slices = [slice(None)] * 4
        agg_slices[width_dim_idx] = slice(-halo_width_send_left, None)

        target_slice = processed_gradient_tensor[tuple(agg_slices)]
        target_slice.add_(buffer_for_gradient_from_right)

    elif current_process_rank == conv_tp_group[-1]:
        # Crop the part sent to the left neighbor
        crop_slices = [slice(None)] * 4
        crop_slices[width_dim_idx] = slice(halo_width_send_left, None)
        processed_gradient_tensor = processed_gradient_tensor[
            tuple(crop_slices)
        ]

        # Aggregate gradient received from the left neighbor
        agg_slices = [slice(None)] * 4
        agg_slices[width_dim_idx] = slice(None, halo_width_send_right)

        target_slice = processed_gradient_tensor[tuple(agg_slices)]
        target_slice.add_(buffer_for_gradient_from_left)

    else:
        # Crop parts sent to both left and right neighbors
        crop_slices = [slice(None)] * 4
        crop_slices[width_dim_idx] = slice(
            halo_width_send_left, -halo_width_send_right
        )
        processed_gradient_tensor = processed_gradient_tensor[
            tuple(crop_slices)
        ]

        # Aggregate gradient received from the right neighbor
        agg_slices_right_edge = [slice(None)] * 4
        agg_slices_right_edge[width_dim_idx] = slice(
            -halo_width_send_left, None
        )
        target_slice_right = processed_gradient_tensor[
            tuple(agg_slices_right_edge)
        ]
        target_slice_right.add_(buffer_for_gradient_from_right)

        # Aggregate gradient received from the left neighbor
        agg_slices_left_edge = [slice(None)] * 4
        agg_slices_left_edge[width_dim_idx] = slice(None, halo_width_send_right)
        target_slice_left = processed_gradient_tensor[
            tuple(agg_slices_left_edge)
        ]
        target_slice_left.add_(buffer_for_gradient_from_left)

    return processed_gradient_tensor.contiguous()


class RingConv2d(paddle.autograd.PyLayer):
    @staticmethod
    def _is_supported(
        input_size, kernel_size, stride, padding, dilation, data_format="NCHW"
    ):
        idx_w_input = -1
        idx_w_kernel = -1

        if data_format == "NCHW":
            # input_size: (N, C, H, W)
            # kernel_size: (OutChannels, InChannels/Groups, KernelH, KernelW)
            idx_w_input = 3
            idx_w_kernel = 3
        elif data_format == "NHWC":
            # input_size: (N, H, W, C)
            # kernel_size: (OutChannels, InChannels/Groups, KernelH, KernelW)
            idx_w_input = 2
            idx_w_kernel = 3
        else:
            raise ValueError(
                f"Unsupported data_format '{data_format}'. Expected 'NCHW' or 'NHWC'."
            )

        dilation_w = dilation[1]
        padding_w = padding[1]
        stride_w = stride[1]

        input_w = input_size[idx_w_input]
        kernel_w = kernel_size[idx_w_kernel]

        if dilation_w != 1:
            # RingConv2d only supports dilation=1.
            # Larger dilation would require enlarged halo regions and more complex communication.
            raise RuntimeError(
                f"Only dilation=1 on the W-dimension is supported for tensor-parallel convolution. "
                f"Got dilation_w={dilation_w} (data_format='{data_format}')."
            )

        if padding_w == 0:
            # To avoid halo exchange when padding=0, we require:
            # - input_w must be divisible by stride_w, so partitions align evenly across ranks.
            # - stride_w == kernel_w, so each kernel operates on disjoint local regions.
            if input_w % stride_w != 0:
                raise RuntimeError(
                    f"When padding_w=0, input_W={input_w} must be divisible by stride_W={stride_w} "
                    f"for tensor-parallel convolution (data_format='{data_format}')."
                )
            if stride_w != kernel_w:
                raise RuntimeError(
                    f"When padding_w=0, stride_W={stride_w} must equal kernel_W={kernel_w} "
                    f"to avoid halo exchange (data_format='{data_format}')."
                )

        else:
            # When padding > 0, halo exchange is needed.
            # To simplify halo logic, we require:
            # - stride_w == 1: ensures each output element is computed from overlapping input,
            #   and no input region is skipped, simplifying halo construction.
            # - kernel_w // 2 <= input_w: prevents the kernel from exceeding local input.
            if stride_w != 1:
                raise RuntimeError(
                    f"When padding_w={padding_w}, stride_W must be 1 for tensor-parallel convolution. "
                    f"Got stride_W={stride_w} (data_format='{data_format}')."
                )
            if kernel_w // 2 > input_w:
                raise RuntimeError(
                    f"Half of kernel_W ({kernel_w // 2}) must not exceed input_W={input_w} "
                    f"to ensure halo region fits (data_format='{data_format}')."
                )

        return True

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        stride=1,
        padding=0,
        padding_algorithm=None,
        dilation=1,
        groups=1,
        data_format="NCHW",
        channel_dim=1,
    ):
        rank = dist.get_rank()

        assert RingConv2d._is_supported(
            x.shape, weight.shape, stride, padding, dilation, data_format
        )
        assert x.is_dist(), "Input tensor `x` must be a distributed tensor."

        if not weight.is_dist():
            weight_placements = [
                dist.Replicate() for _ in range(len(x.placements))
            ]
            weight = dist.auto_parallel.api.dtensor_from_local(
                weight, x.process_mesh, weight_placements
            )

        if bias is not None and not bias.is_dist():
            bias_placements = [
                dist.Replicate() for _ in range(len(x.placements))
            ]
            bias = dist.auto_parallel.api.dtensor_from_local(
                bias, x.process_mesh, bias_placements
            )

        ctx.save_for_backward(x, weight, bias)

        x_mesh = x.process_mesh
        x_placements = x.placements
        x = dist.auto_parallel.api.dtensor_to_local(x, x_mesh, x_placements)

        weight = dist.auto_parallel.api.dtensor_to_local(
            weight, weight.process_mesh, weight.placements
        )
        if bias is not None:
            bias = dist.auto_parallel.api.dtensor_to_local(
                bias, bias.process_mesh, bias.placements
            )

        ctx.attrs = (
            stride,
            padding,
            padding_algorithm,
            dilation,
            groups,
            data_format,
        )

        mesh_axis_name, conv_tp_group = _get_conv_tp_group(
            x_mesh, x_placements, data_format
        )
        if padding[1] == 0 or len(conv_tp_group) <= 1:
            final_local_results = paddle._C_ops.conv2d(
                x,
                weight,
                stride,
                padding,
                padding_algorithm,
                dilation,
                groups,
                data_format,
            )
        else:
            # step 0: calculate the required overlap (halo) pixels for the input tensor
            if data_format == "NCHW":
                kernel_width_dim_idx = 3
                output_width_dim_idx = 3
            elif data_format == "NHWC":
                kernel_width_dim_idx = 3
                output_width_dim_idx = 2
            else:
                raise ValueError(
                    f"Unsupported data_format: {data_format}. Must be 'NCHW' or 'NHWC'."
                )

            kernel_width = weight.shape[kernel_width_dim_idx]
            kernel_total_halo_span = kernel_width - 1
            left_halo_width = kernel_total_halo_span // 2
            right_halo_width = kernel_total_halo_span - left_halo_width
            assert left_halo_width + right_halo_width == kernel_total_halo_span

            ctx.mesh_axis_name = mesh_axis_name
            rank_idx = conv_tp_group.index(rank)
            next_rank = conv_tp_group[(rank_idx + 1) % len(conv_tp_group)]
            prev_rank = conv_tp_group[(rank_idx - 1) % len(conv_tp_group)]

            # step 1: reconstruct the local input tensor including halo regions via ring communication
            # `x` is updated here, now including halo data received from neighboring ranks.
            x = _ring_conv_halo_exchange(
                x,
                left_halo_width,
                right_halo_width,
                prev_rank,
                next_rank,
                rank,
                conv_tp_group,
                data_format,
            )

            # step 2: feed the reconstructed local input tensor to the actual computation (op_call)
            local_results_with_halo = paddle._C_ops.conv2d(
                x,
                weight,
                stride,
                padding,
                padding_algorithm,
                dilation,
                groups,
                data_format,
            )

            # step 3: remove extra output portions from the results, generated from processing halo regions
            # `padding[1]` (from outer scope) is assumed here to be the width of the halo/overlap
            # that needs to be trimmed from each side of the output.
            output_halo_trim_width = padding[1]
            width_before_trimming = local_results_with_halo.shape[
                output_width_dim_idx
            ]

            if data_format == "NCHW":
                if rank == conv_tp_group[0]:
                    final_local_results = local_results_with_halo[
                        :,
                        :,
                        :,
                        : width_before_trimming - output_halo_trim_width,
                    ]
                elif rank == conv_tp_group[-1]:
                    final_local_results = local_results_with_halo[
                        :, :, :, output_halo_trim_width:
                    ]
                else:
                    final_local_results = local_results_with_halo[
                        :,
                        :,
                        :,
                        output_halo_trim_width : width_before_trimming
                        - output_halo_trim_width,
                    ]
            else:
                if rank == conv_tp_group[0]:
                    final_local_results = local_results_with_halo[
                        :,
                        :,
                        : width_before_trimming - output_halo_trim_width,
                        :,
                    ]
                elif rank == conv_tp_group[-1]:
                    final_local_results = local_results_with_halo[
                        :, :, output_halo_trim_width:, :
                    ]
                else:
                    final_local_results = local_results_with_halo[
                        :,
                        :,
                        output_halo_trim_width : width_before_trimming
                        - output_halo_trim_width,
                        :,
                    ]

            ctx.left_halo_width = left_halo_width
            ctx.right_halo_width = right_halo_width
            ctx.output_halo_trim_width = output_halo_trim_width
            ctx.output_width_dim_idx = output_width_dim_idx

        final_local_results = dist.auto_parallel.api.dtensor_from_local(
            final_local_results, x_mesh, x_placements
        )

        return final_local_results.contiguous()

    @staticmethod
    def backward(ctx, grad_out):
        current_rank = dist.get_rank()
        x, weight, bias = ctx.saved_tensor()

        x_stop_gradient = x.stop_gradient
        weight_stop_gradient = weight.stop_gradient
        bias_stop_gradient = bias.stop_gradient if bias is not None else True

        x_mesh = x.process_mesh
        x_placements = x.placements
        x = dist.auto_parallel.api.dtensor_to_local(x, x_mesh, x_placements)

        weight_mesh = weight.process_mesh
        weight_placements = weight.placements
        weight = dist.auto_parallel.api.dtensor_to_local(
            weight, weight_mesh, weight_placements
        )

        grad_out = dist.auto_parallel.api.dtensor_to_local(
            grad_out, grad_out.process_mesh, grad_out.placements
        )

        if bias is not None:
            bias_mesh = bias.process_mesh
            bias_placements = bias.placements
            bias = dist.auto_parallel.api.dtensor_to_local(
                bias, bias_mesh, bias_placements
            )

        conv_attrs = ctx.attrs
        data_format = conv_attrs[-1]
        padding = conv_attrs[1]

        grad_x = None
        grad_weight = None
        grad_bias = None

        _, conv_tp_group = _get_conv_tp_group(x_mesh, x_placements, data_format)

        if padding[1] == 0 or len(conv_tp_group) <= 1:
            grad_x, grad_weight = paddle._C_ops.conv2d_grad(
                x, weight, grad_out, *conv_attrs
            )
        else:
            rank_idx = conv_tp_group.index(current_rank)
            next_rank = conv_tp_group[(rank_idx + 1) % len(conv_tp_group)]
            prev_rank = conv_tp_group[(rank_idx - 1) % len(conv_tp_group)]

            left_halo_width = ctx.left_halo_width
            right_halo_width = ctx.right_halo_width
            output_halo_trim_width = ctx.output_halo_trim_width
            output_width_dim_idx = ctx.output_width_dim_idx

            # Step 1: Reconstruct `in_tensor_augmented` (original input to local conv in forward)
            in_tensor_augmented = _ring_conv_halo_exchange(
                x,
                left_halo_width,
                right_halo_width,
                prev_rank,
                next_rank,
                current_rank,
                conv_tp_group,
                data_format,
            )

            # Step 2: Pad `grad_out` to match the output shape of conv on augmented input
            padding_w = padding[1]
            if data_format == "NCHW":
                if current_rank == conv_tp_group[0]:
                    padding_list = [0, padding_w]
                elif current_rank == conv_tp_group[-1]:
                    padding_list = [padding_w, 0]
                else:
                    padding_list = [padding_w, padding_w]
            else:
                if current_rank == conv_tp_group[0]:
                    padding_list = [0, padding_w, 0, 0]
                elif current_rank == conv_tp_group[-1]:
                    padding_list = [padding_w, 0, 0, 0]
                else:
                    padding_list = [padding_w, padding_w, 0, 0]

            grad_out_padded = F.pad(
                grad_out,
                padding_list,
                mode="constant",
                value=0.0,
                data_format=data_format,
            )

            # Step 3: Local backward computation using augmented/padded tensors
            # `padding` here is the original conv padding from forward.
            grad_x_augmented, grad_weight = paddle._C_ops.conv2d_grad(
                in_tensor_augmented, weight, grad_out_padded, *conv_attrs
            )

            # Step 4: Aggregate "halo" regions for grad_input
            if not x_stop_gradient:
                grad_x = _ring_conv_halo_aggregate(
                    grad_x_augmented,
                    left_halo_width,
                    right_halo_width,
                    prev_rank,
                    next_rank,
                    current_rank,
                    conv_tp_group,
                    data_format,
                )

        if bias is not None:
            sum_axes = [0, 2, 3] if data_format == "NCHW" else [0, 1, 2]
            grad_bias = paddle.sum(grad_out, axis=sum_axes, keepdim=True)
            grad_bias = grad_bias.reshape(bias.shape)

        if grad_x is not None:
            grad_x = dist.auto_parallel.api.dtensor_from_local(
                grad_x, x_mesh, x_placements
            )

        # Note(luchang): With input X sharded along tp_axis_name and weight W replicated,
        # the locally computed grad_weight is only a partial sum for the full dL/dW,
        # as dL/dW depends on contributions from all input shards.
        # Aggregation across TP ranks is therefore necessary. Partial(ReduceSum)
        # declares this averaging intent, and reshard to Replicate() executes
        # the AllReduce-average, making the correct averaged grad_weight available
        # and replicated on all TP ranks.
        tp_axis_name, _ = _get_conv_tp_group(x_mesh, x_placements, data_format)
        for idx, axis_name in enumerate(weight_mesh.dim_names):
            if axis_name == tp_axis_name:
                weight_placements[idx] = dist.Partial(dist.ReduceType.kRedSum)
                if bias is not None:
                    bias_placements[idx] = dist.Partial(dist.ReduceType.kRedSum)

        grad_weight = dist.auto_parallel.api.dtensor_from_local(
            grad_weight, weight_mesh, weight_placements
        )
        # do allreduce to get right grad_weight
        grad_weight = dist.reshard(
            grad_weight,
            weight_mesh,
            [dist.Replicate() for _ in range(len(weight_placements))],
        )

        if bias is not None:
            grad_bias = dist.auto_parallel.api.dtensor_from_local(
                grad_bias, bias_mesh, bias_placements
            )
            # do allreduce to get right grad_bias
            grad_bias = dist.reshard(
                grad_bias,
                weight_mesh,
                [dist.Replicate() for _ in range(len(bias_placements))],
            )

        if x_stop_gradient:
            grad_x = None
        if weight_stop_gradient:
            grad_weight = None
        if bias_stop_gradient:
            grad_bias = None

        if bias is not None:
            return grad_x, grad_weight, grad_bias

        return grad_x, grad_weight
