
    x-jPW                    B   d Z ddlmZ ddlZddlZddlZddlmZ ddlm	Z	 ddl
mZmZ ddlZerddlm
Z g Z G d de          Z G d	 d
e          Zd#dZd Zd$dZd%dZd Zd$dZd%dZ ej                    ai ad Zd%dZej         ddfd&d!Z!ej"        ddfd'd"Z#dS )(z#
Utilities of Auto SParsity (ASP).
    )annotationsN)Enum)permutations)TYPE_CHECKINGAnyc                      e Zd ZdZdZdZdZdS )MaskAlgoz
    A collection of all mask generating algorithms.
    There currently are three algorithms, `MASK_1D`, `MASK_2D_GREEDY` and `MASK_2D_BEST`
    get_mask_1dget_mask_2d_greedyget_mask_2d_bestN)__name__
__module____qualname____doc__MASK_1DMASK_2D_GREEDYMASK_2D_BEST     Y/var/www/html/banglarbhumi/venv/lib/python3.11/site-packages/paddle/incubate/asp/utils.pyr	   r	   $   s)         
 G)N%LLLr   r	   c                  2    e Zd ZdZdZdZed	d            ZdS )
CheckMethodzz
    A collection of all sparsity checking approaches.
    There currently are two methods, `CHECK_1D` and `CHECK_2D`
    check_mask_1dcheck_mask_2d	mask_algor	   returnc                    t          | t                    s
J d            | t          j        k    rt          j        S t          j        S )a  
        Get sparsity checking method by mask generating algorithm.

        Args:
            mask_algo (MaskAlgo): The algorithm of mask generating.
        Returns:
            CheckMethod: The corresponded sparsity checking method.
        Examples:
            .. code-block:: python

                >>> import numpy as np
                >>> from paddle.incubate.asp import CheckMethod, MaskAlgo
                >>> print(CheckMethod.get_checking_method(MaskAlgo.MASK_1D))
                CheckMethod.CHECK_1D
                >>> print(CheckMethod.get_checking_method(MaskAlgo.MASK_2D_GREEDY))
                CheckMethod.CHECK_2D
                >>> print(CheckMethod.get_checking_method(MaskAlgo.MASK_2D_BEST))
                CheckMethod.CHECK_2D
        z!mask_algo should be MaskAlgo type)
isinstancer	   r   r   CHECK_1DCHECK_2D)r   s    r   get_checking_methodzCheckMethod.get_checking_method8   sM    * )X.. 	
 	
/	
 	
. (((''''r   N)r   r	   r   r   )r   r   r   r   r   r    staticmethodr!   r   r   r   r   r   /   sH         
 HH( ( ( \( ( (r   r   xnpt.NDArray[Any]r   floatc                    |                                  }t          t          j        |          d         j                  |j        z  S )a  

    Return the density of the input tensor.

    Args:
        x (nparray): The input tensor.

    Returns:
        float, The density of :attr:`x`.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import numpy as np

            >>> x = np.array([[0, 1, 3, 0],
            ...             [1, 1, 0, 1]])
            >>> out = paddle.incubate.asp.calculate_density(x)
            >>> print(out)
            0.625

    r   )flattenr%   npnonzerosize)r#   x_flatteneds     r   calculate_densityr,   V   s9    0 ))++KK((+011K4DDDr   c                   t          | j                  dk    s
J d            | j        d         |z  }| j        d         |z  dk    rgt          j        | j        d         | j        d         ||z
  z   f          }| |ddd| j        d         f<   |j        }|                    d|          |fS |                     d|          | j        fS )a  
    Reshape the input 2D matrix to shape (-1, m).
    If the second dimension of :attr:`mat` is not a multiples of :attr:`m`,
    then this function would pad the remainder with 0 before reshaping.

    .. math::

        remainder = mat.shape[1] % m

    Args:
        mat (nparray): The input 2D matrix.
        m (int): The second dimension of reshaped matrix.
    Returns:
        tuple: A pair of the reshaped and padded matrix and the shape of padded matrix (non-reshaping).
       $The input mat should be a 2D matrix!   r   N)lenshaper(   zerosreshape)matm	remainder
mat_paddedr3   s        r   _reshape_1dr:   r   s      sy>>Q F	!q I
y|a!Xsy|SYq\Q]-KLMM
(+
111n	!n$% !!"a((%//{{2q!!39,,r   r6   nintr7   boolc                   t          | j                  dk    r3t          |                     d| j        d                   |          \  }}nt          | |          \  }}|D ]+}t	          j        |          d         j        ||z
  k    r dS ,dS )a  
    Check if every row of the input matrix :attr:`mat` is in 1D `n:m` sparse pattern.
    This function would pad the second dimension of :attr:`mat` by zero
    to be a multiples of :attr:`m` if necessary.

    1D `n:m` sparse pattern: At least :attr:`n` zeros in every :math:`1 \times m` block.

    Args:
        mat (nparray): The input matrix.
        n (int): n of `n:m` sparse pattern.
        m (int): m of `n:m` sparse pattern.
    Returns:
        bool: True if every row of :attr:`mat` is in 1D n:m sparse pattern, else False.
    Examples:
        .. code-block:: python

          >>> import numpy as np
          >>> import paddle.incubate.asp as sparsity

          >>> x = np.array([[0, 1, 3, 0],
          ...               [1, 0, 0, 1]])
          >>> y = sparsity.check_mask_1d(x, 2, 4)
          >>> print(y)
          True

          >>> x = np.array([[0, 1, 5, 4],
          ...               [1, 0, 0, 1]])
          >>> y = sparsity.check_mask_1d(x, 2, 4)
          >>> print(y)
          False

          >>> # x would be padded to shape (2, 8)
          >>> x = np.array([[0, 1, 0, 4, 6],
          ...               [1, 0, 0, 1, 7]])
          >>> y = sparsity.check_mask_1d(x, 2, 4)
          >>> print(y)
          True
    r0   r   FT)r2   r3   r:   r5   r(   r)   r*   )r6   r;   r7   mat_flattenr3   sub_mats         r   r   r      s    N 39~~(Q	!)E)EqIIUU(a00U  :gq!&!a%0055 14r   c                   t          | |          \  }}t          j        |          }t          j        |           }t          |j        d                   D ]Q}||         }t          j        t          j        |                    }	d|||	d|                                         f<   R|                    |          }|ddd| j        d         f         |ddddf<   |S )a  
    Generate 1D `n:m` sparse pattern mask of the input matrix :attr:`mat`
    in row-directory. This function would pad the second dimension of :attr:`mat`
    by zero to be a multiples of :attr:`m` before mask generation.

    1D `n:m` sparse pattern: At least :attr:`n` zeros in every :math:`1 \times m` block.

    Args:
        mat (nparray): The input matrix.
        n (int): n of `n:m` sparse pattern.
        m (int): m of `n:m` sparse pattern.
    Returns:
        nparray: The 1D `n:m` sparse mask of :attr:`mat`.
    Examples:
        .. code-block:: python

          >>> import numpy as np
          >>> import paddle.incubate.asp as sparsity
          >>> mat = np.array([[0, 1, 5, 4],
          ...                 [2, 7, 3, 6]])
          >>> mask = sparsity.get_mask_1d(mat, 2, 4)
          >>> print(mask)
          [[0 0 1 1]
          [0 1 0 1]]
          >>> y = sparsity.check_mask_1d(mask, 2, 4)
          >>> print(y)
          True
    r   Nr0   )	r:   r(   	ones_likeranger3   argsortabsolutetolistr5   )
r6   r;   r7   r?   r3   mask_flattenmaskir@   min_order_indicess
             r   r
   r
      s    : %S!,,K<,,L<D;$Q'(( < <a.Jr{7';';<<:;Q)"1"-4466677''..Laaa39Q</0DAAAJKr   c                   t          | j                  dk    s
J d            | j        d         |z  }| j        d         |z  }|dk    r| j        d         n| j        d         ||z
  z   |dk    r| j        d         n| j        d         ||z
  z   f}t          j        |          }| |d| j        d         d| j        d         f<   t          j        |                              d||z            }d}t          d|j        d         |          D ]h}||z   }	t          d|j        d         |          D ]D}
|
|z   }t          j        |||	|
|f                             d                    }|||<   |dz  }Ei||j        fS )a3  
    Reshape the input 2D matrix to shape (-1, :math:`m \times m`).
    In each dimension of :attr:`mat`, if it is not a multiples of :attr:`m`,
    then this function would pad the remainder with 0 before reshaping.

    .. math::

        remainder_0 = mat.shape[0] % m \\
        remainder_1 = mat.shape[1] % m

    Args:
        mat (nparray): The input 2D matrix.
        m (int): The square root of second dimension of reshaped matrix.
    Returns:
        tuple: A pair of the reshaped and padded matrix and the shape of padded matrix (non-reshaping).
    r.   r/   r   r0   Nr1   )r2   r3   r(   r4   emptyr5   rC   squeeze)r6   r7   remainder_0remainder_1	new_shaper9   r?   curr_idx	row_startrow_end	col_startcol_endr@   s                r   _reshape_2drV      s   " sy>>Q F)A,"K)A,"K $q((	!cila+o.N#q((	!cila+o.NI )$$J14J~1~~1~-.(9%%--b!a%88KH1j.q1155  	a-q*"21"5q99 	 	I!mGj9W,i.??@HHLL G %,K!MHH	 
(((r   c           	        t          | |          \  }}|D ]}t          j        t          j        |                    ||                              dk    }t          j        t          j        |d          ||z
  k              dk    r6t          j        t          j        |d          ||z
  k              dk    r dS dS )a  
    Check if every :math:`m \times m` block of the input matrix :attr:`mat` is in 2D `n:m` sparse pattern.
    This function would pad each dimension of :attr:`mat` by zero to be a multiples of
    :attr:`m` if necessary.

    2D `n:m` sparse pattern: At least :math:`n \times n` zeros in every :math:`m \times m` block
    under the constraint of at least :attr:`n` zeros for each row and column.

    Args:
        mat (nparray): The input matrix.
        n (int): n of `n:m` sparse pattern.
        m (int): m of `n:m` sparse pattern.
    Returns:
        bool: True if  every :math:`m \times m` block of the input matrix :attr:`mat` is in 2D `n:m` sparse pattern, else False.
    Examples:
        .. code-block:: python

          >>> import numpy as np
          >>> import paddle.incubate.asp as sparsity

          >>> x = np.array([[0, 8, 9, 0],
          ...               [9, 0, 0, 10],
          ...               [5, 0, 0, 6],
          ...               [0, 4, 6, 0]])
          >>> y = sparsity.check_mask_2d(x, 2, 4)
          >>> print(y)
          True

          >>> x = np.array([[0, 8, 0, 9],
          ...               [9, 0, 0, 10],
          ...               [0, 5, 0, 6],
          ...               [0, 4, 6, 0]])
          >>> y = sparsity.check_mask_2d(x, 2, 4)
          >>> print(y)
          True

          >>> # x would be padded to shape (8, 8)
          >>> x = np.array([[0, 8, 0, 9],
          ...               [9, 0, 7, 0],
          ...               [0, 5, 0, 6],
          ...               [3, 0, 6, 0],
          ...               [1, 1, 0, 1]])
          >>> y = sparsity.check_mask_2d(x, 2, 4)
          >>> print(y)
          True
    r   r0   axisFT)rV   r(   rE   rM   r5   sum)r6   r;   r7   r9   r3   r@   sub_masks          r   r   r     s    ^ $C++J  ;rz'//!Q*?*?@@AAAEF26(+++q1u566!;;F26(+++q1u566!;;554r   c                   t          |           \  }}t          j        |                              d          }t	          t          |                    D ]!}t          j        t          j        ||                             }t          j        ||                   }t          j        |          }	fd|	D             }
t          j
                    }t          j
                    }t	          t          |	          dz
  dd          D ]n}|
|         }||d                  |k    s||d                  |k    r/d||d         |d         f<   ||d         xx         dz  cc<   ||d         xx         dz  cc<   o#t          j        |          }d}t	          d|d                   D ]<}|z   }t	          d|d                   D ]}|z   }||         |||||f<   |dz  }=|d| j        d         d| j        d         f         S )a  
    Greedily generate 2D `n:m` sparse pattern mask of the input matrix :attr:`mat`.
    This function would pad each dimension of :attr:`mat` by zero to be a multiples of :attr:`m` before mask generation.

    2D `n:m` sparse pattern: At least :math:`n \times n` zeros in every :math:`m \times m` block
    under the constraint of at least :attr:`n` zeros for each row and column.
    Greedily generating: For each :math:`m \times m` block, selecting values to keep in descent order.

    Args:
        mat (nparray): The input matrix.
        n (int): n of `n:m` sparse pattern.
        m (int): m of `n:m` sparse pattern.
    Returns:
        nparray: The 2D `n:m` sparse mask of :attr:`mat`.
    Examples:
        .. code-block:: python

          >>> import numpy as np
          >>> import paddle.incubate.asp as sparsity

          >>> mat = np.array([[9, 8, 3, 7],
          ...                 [9, 2, 1, 10],
          ...                 [5, 1, 3, 6],
          ...                 [2, 4, 6, 1]])
          >>> mask = sparsity.get_mask_2d_greedy(mat, 2, 4)
          >>> print(mask)
          [[1. 1. 0. 0.]
          [1. 0. 0. 1.]
          [0. 0. 1. 1.]
          [0. 1. 1. 0.]]
          >>> y = sparsity.check_mask_2d(mask, 2, 4)
          >>> print(y)
          True
    r1   c                >    g | ]}t          |z            |z  fS r   )r<   ).0r#   r7   s     r   
<listcomp>z&get_mask_2d_greedy.<locals>.<listcomp>{  s9      
  
  
$%SQZZQ 
  
  
r   r0   r   g      ?N)rV   r(   
zeros_liker5   rC   r2   rE   rM   rD   collectionsCounterrL   r3   )r6   r;   r7   r9   r3   mask_paddedidxr@   r[   min_order_1d_indicesmin_order_2d_indicesrow_countercol_counterrI   matrix_entryrH   rQ   rR   rS   rT   rU   s     `                  r   r   r   N  sU   J $C++J-
++33B1==KS__%% . .+bjC99:::k#.//!z'22 
  
  
  
)= 
  
  
 ")++!)++s/0014b"== 		. 		.A/2LLO,11LO,119<H\!_l1o56Q(((A-(((Q(((A-((((		. 8E??DH1eAh**  	a-q%(A.. 	 	I!mG9DX9ND7"Ig$556MHH	 #)A,#)A,.//r   c           
        | d|  }|t           v rt           |         S t          j        |          }d|d| <   t          t	          t          |                                                              }||z   }t          j        t          t	          t          ||                                        }|                    d          | k                        d          |k    	                                d         
                    d          }t          j        |j        d         ||f          }||dd                  |dd<   t                                           |t           |<   t                                           |S )a  
    Compute all valid 2D `n:m` sparse patterns.

    2D `n:m` sparse pattern: At least :math:`n \times n` zeros in every :math:`m \times m` block
    under the constraint of at least :attr:`n` zeros for each row and column.

    Args:
        n (int): n of `n:m` sparse pattern.
        m (int): m of `n:m` sparse pattern.
    Returns:
        dictionary: A dictionary with key: *m_n* (string) and value: all valid 2D `n:m` sparse patterns.
    _r0   NrX   r   r1   )_valid_2d_patternsr(   r4   listsetr   rF   asarrayrZ   r)   r5   rL   r3   _valid_2d_patterns_lockacquirerelease)r;   r7   	valid_keypatternsvalidvalid_patternss         r   _compute_valid_2d_patternsrw     sZ     

q

I&&&!),,8A;;!L):):;;<<==h&:d3|Ha'@'@#A#ABBCC lll""a',,!,449WYYqWR[[ 	
 5;q>1a"899$U111X.qqq'')))(69%'')))r   c           
        t          ||          }t          | |          \  }}t          j        |                              d||          }t          j        t          j        ||                    |j        d         ||z            j                  d          }||dd                  |dd<   t          j	        |          }d}	t          d|d         |          D ]<}
|
|z   }t          d|d         |          D ]}||z   }||	         ||
|||f<   |	dz  }	=|d| j        d         d| j        d         f         S )a  
    Generate 2D `n:m` sparse pattern mask of the input matrix :attr:`mat`
    to form sparse matrix with maximum L1 norm .This function would pad each
    dimension of :attr:`mat` by zero to be a multiples of :attr:`m` before mask generation.

    2D `n:m` sparse pattern: At least :math:`n \times n` zeros in every :math:`m \times m` block
    under the constraint of at least :attr:`n` zeros for each row and column.

    *Note*: L1 norm of sparse matrix from `Best` API is greater than or equal to the one from `Greedy`.

    Args:
        mat (nparray): The input matrix.
        n (int): n of `n:m` sparse pattern.
        m (int): m of `n:m` sparse pattern.
    Returns:
        nparray: The 1D `n:m` sparse mask of :attr:`mat`.
    Examples:
        .. code-block:: python

          >>> import numpy as np
          >>> import paddle.incubate.asp as sparsity

          >>> mat = np.array([[2, 8, 9, 9],
          ...                 [9, 1, 3, 9],
          ...                 [5, 6, 3, 9],
          ...                 [2, 4, 6, 9]])
          >>> mask_greedy = sparsity.get_mask_2d_greedy(mat, 2, 4)
          >>> mask_best = sparsity.get_mask_2d_best(mat, 2, 4)
          >>> print("L1 norm of `greedy` sparse matrix", np.multiply(mat, mask_greedy).sum())
          L1 norm of `greedy` sparse matrix 56.0
          >>> print("L1 norm of `best` sparse matrix", np.multiply(mat, mask_best).sum())
          L1 norm of `best` sparse matrix 61.0
    r1   r   r0   rX   N)rw   rV   r(   rB   r5   argmaxmatmulr3   TrL   rC   )r6   r;   r7   rt   r?   r3   rG   pmaxrH   rQ   rR   rS   rT   rU   s                 r   r   r     sa   D *!Q//H$S!,,K<,,44RA>>L9
	+x//q0A1q5IIKLL  D
 tAAAw'LO8E??DH1eAh**  	a-q%(A.. 	 	I!mG9Eh9OD7"Ig$556MHH	 #)A,#)A,.//r   r.      tensor	func_namec                   | j         }| j        }|                     t                    }t	          |t
                    sJ dt          |                       t          t          j	        t                   |j        d          }t          |          dk    r|                    d|d                   }nXt          |          dk    r$|                    |d         |d                   }n!t          |          dk    r,|                    |d         |d         z  |d                   }nt          |          dk    r|                    g d                              |d         |d         z  |d         z  |d                   } ||||	          }|                    |d         |d         |d         |d         g                              g d                              |          S t          d
t          |                      ||||	          }|                    |                              |          S )aI  
    Create `n:m` sparse pattern mask of the input tensor via function given by :attr:`func_name`.
    Currently only support tensor with dimension less than or equal to 4.

    Args:
        tensor (nparray): The input tensor.
        func_name (MaskAlgo, optional): The function name to generate sparse mask. Default is `MaskAlgo.MASK_1D`. All options please refer to `MaskAlgo`.
        n (int, optional): n of `n:m` sparse pattern. Default is 2.
        m (int, optional): m of `n:m` sparse pattern. Default is 4.
    Returns:
        nparray: The `n:m` sparse mask of :attr:`tensor` generated by :attr:`func_name`.
    Examples:
        .. code-block:: python

          >>> import numpy as np
          >>> import paddle.incubate.asp as sparsity

          >>> tensor = np.array([[2, 8, 9, 9],
          ...                    [9, 1, 3, 9],
          ...                    [5, 6, 3, 9],
          ...                    [2, 4, 6, 9]])
          >>> mask_1d = sparsity.create_mask(tensor, func_name=sparsity.MaskAlgo.MASK_1D)
          >>> print(mask_1d)
          [[0 0 1 1]
          [1 0 0 1]
          [0 1 0 1]
          [0 0 1 1]]
          >>> mask_2d = sparsity.create_mask(tensor, func_name=sparsity.MaskAlgo.MASK_2D_BEST)
          >>> print(mask_2d)
          [[0 1 1 0]
          [1 0 0 1]
          [1 1 0 0]
          [0 0 1 1]]
    zMfunc_name argument of create_mask is only accepted as type MaskAlgo. But got Nr0   r   r.      r}   r   r0   r   r.   r;   r7   gThe dimension of input tensor is not supported in create_mask, Only dimension < 4 is supported but got )r3   dtypeastyper%   r   r	   typegetattrsysmodulesr   valuer2   r5   	transpose
ValueError)	r~   r   r;   r7   r3   r   tfuncrH   s	            r   create_maskr     s    P LELEeAi**  	%	??	% 	% * 3;x()/4@@D
5zzQIIaq""	UqIIeAha))	UqIIeAhq)5844	UqKK%%--!HuQx%(*E!H
 
 tAa   LL%(E!HeAhaABBY|||$$VE]]	
 D7:5zzD D
 
 	

 4Q!D<<%%e,,,r   c                ^   | j         }|                     t                    }t          |          t          k    sJ dt          |                       t          t          j        t                   |j	        d          }t          |          dk    r|                    d|d                   }nt          |          dk    r#|                    |d         |d                   }nt          |          dk    r,|                    |d         |d         z  |d                   }n}t          |          dk    rK|                    g d                              |d         |d         z  |d         z  |d         g          }nt          d	t          |                      ||||
          S )a  
    Check if input tensor is in `n:m` sparse pattern via function given by :attr:`func_name`.
    Currently only support tensor with dimension less than or equal to 4.

    Args:
        tensor (nparray): The input tensor.
        func_name (CheckMethod, optional): The function name to generate sparse mask. Default is `CheckMethod.CHECK_1D`. All options please refer to `CheckMethod`.
        n (int, optional): n of `n:m` sparse pattern. Default is 2.
        m (int, optional): m of `n:m` sparse pattern. Default is 4.
    Returns:
        bool: True if tensor pass checking of function given by :attr:`func_name`, else False.
    Examples:
        .. code-block:: python

          >>> import numpy as np
          >>> import paddle.incubate.asp as sparsity

          >>> tensor = np.array([[2, 8, 9, 9],
          ...                    [9, 1, 3, 9],
          ...                    [5, 6, 3, 9],
          ...                    [2, 4, 6, 9]])
          >>> mask_1d = sparsity.create_mask(tensor, func_name=sparsity.MaskAlgo.MASK_1D)
          >>> print(mask_1d)
          [[0 0 1 1]
          [1 0 0 1]
          [0 1 0 1]
          [0 0 1 1]]
          >>> y = sparsity.check_sparsity(mask_1d, func_name=sparsity.CheckMethod.CHECK_1D)
          >>> print(y)
          True
          >>> y = sparsity.check_sparsity(mask_1d, func_name=sparsity.CheckMethod.CHECK_2D)
          >>> print(y)
          True
    zSfunc_name argument of check_sparsity is only accepted as type CheckMethod. But got Nr0   r   r.   r   r}   r   r   r   )r3   r   r%   r   r   r   r   r   r   r   r2   r5   r   r   )r~   r   r;   r7   r3   r   r   s          r   check_sparsityr   H  s   P LEeA	??k)))	%	??	% 	% *)) 3;x()/4@@D
5zzQIIaq""	UqIIeAha))	UqIIeAhq)5844	UqKK%%--1Xa 58+U1X6
 
 D7:5zzD D
 
 	

 4Q!r   )r#   r$   r   r%   )r6   r$   r;   r<   r7   r<   r   r=   )r6   r$   r;   r<   r7   r<   r   r$   )
r~   r$   r   r	   r;   r<   r7   r<   r   r$   )
r~   r$   r   r   r;   r<   r7   r<   r   r=   )$r   
__future__r   ra   r   	threadingenumr   	itertoolsr   typingr   r   numpyr(   numpy.typingnpt__all__r	   r   r,   r:   r   r
   rV   r   r   Lockrp   rl   rw   r   r   r   r   r   r   r   r   <module>r      s1    # " " " " "     



           " " " " " " % % % % % % % %     
& & & & &t & & &$( $( $( $( $($ $( $( $(NE E E E8- - -8/ / / /d' ' ' 'T() () ()V6 6 6 6rF0 F0 F0 F0R )).**  & & &R50 50 50 50t #*	I- I- I- I- I-\ )1	A A A A A A Ar   