
    x-jD                     $   d dl Z d dlZd dlmZ d dlmZ d dlZd dlmZ ddl	m
Z
  e e ej        dd                              Z e e ej        d	d                              Z ej        ee j        d
          Zd Z	 	 	 	 	 d'dZd(dZd Zd Zd Zd Zd Zd Z G d d          Z G d de          Z G d d          Z G d d           Z d! Z!d" Z" G d# d$ej#        j$                  Z% G d% d&ej&        j'                  Z(dS ))    N)deque)Enum)
log_helper   )	CUDAGraph+PADDLE_DEBUG_ENABLE_CUDAGRAPH_LAYER_LOGGING01PADDLE_DEBUG_CUDAGRAPHEDLAYER_FALLBACK_TO_DEFAULTz[%(levelname)s] %(message)s)fmtc                 L    t           sd S t                              |            d S N)enable_debug_printloggerinfoxs    e/var/www/html/banglarbhumi/venv/lib/python3.11/site-packages/paddle/device/cuda/cuda_graphed_layer.pydebug_printr   &   s"     
KKNNNNN    UnnamedTFc                    g }|r|                     |           |d }| t          | d           d S t          | t          j                  r|r:|                     d| j                    |                     d| j                    |r7|                     dt          |                                                       |r!|                     d ||                       t          d	                    |                     d S d S )Nc                 p    t          |                     d          dz                                            S )Nfloat32i  )floatastypesum)ts    r   <lambda>zprint_tensor.<locals>.<lambda>8   s+     3 3d :??AABB r   z is Nonezshape = zplace = zptr = zhash = z | )
appendr   
isinstancepaddleTensorshapeplacehexdata_ptrjoin)r   name
print_meta	print_ptr
print_hashhashoutputs          r   print_tensorr.   ,   s0    F d|BByt%%%&&&&&	Av}	%	% ( 	0MM.QW..///MM.QW../// 	8MM63qzz||#4#466777 	/MM-DDGG--...EJJv&&'''''( (r   printerc                     t           sd S t          |                    dd                     t          t          |            d S )Nd   -)r   r   centerrecursive_applyr.   )r   banners     r   r/   r/   G   sA     c3''(((L!$$$$$r   c                 *    t          |t                    r fd|D             S t          |t                    rt           fd|D                       S t          |t                    r  fd|                                D             S   |          S )Nc                 0    g | ]}t          |          S  r4   .0itemfunctions     r   
<listcomp>z#recursive_apply.<locals>.<listcomp>R   s#    FFFD$//FFFr   c              3   8   K   | ]}t          |          V  d S r   r9   r:   s     r   	<genexpr>z"recursive_apply.<locals>.<genexpr>T   s-      KK_Xt44KKKKKKr   c                 8    i | ]\  }}|t          |          S r8   r9   )r;   keyvaluer=   s      r   
<dictcomp>z#recursive_apply.<locals>.<dictcomp>V   s9     
 
 
U 511
 
 
r   )r    listtupledictitems)r=   	input_vars   ` r   r4   r4   P   s    )T"" 
#FFFFIFFFF	Iu	%	% #KKKKKKKKKK	It	$	$ #
 
 
 
'oo//
 
 
 	

 x	"""r   c                 ~    t          | t          j                  r"|                                 }| j        |_        |S | S r   )r    r!   r"   detachstop_gradient)tensordetached_tensors     r   detach_tensorrO   ^   s9    &&-((  --//(.(<%Mr   c                 6    g fd}t          ||            S )Nc                 j    t          | t          j                  r                    |            d S d S r   )r    r!   r"   r   )argrets    r   r   z!recursive_flatten.<locals>.appendl   s6    c6=)) 	 JJsOOOOO	 	r   r9   )targetr   rS   s     @r   recursive_flattenrU   i   s8    
C     FF###Jr   c                     g t          |           t          t          |                                                    S r   )rU   rF   values)argskwargss     r   recursive_flatten_args_kwargsrZ   {   s<    	4	 	 	511	2	2 r   c                 ,    t          t          |           S r   )r4   rO   r   s    r   r   r      s    ?=!44 r   c                 Z    t          | t          j                  r| j        rdS | j        S dS )zWReturns the gradient of a Paddle Tensor if it's a tensor; otherwise, returns the input.N)r    r!   r"   rL   gradr   s    r   get_grad_tensorr^      s1    !V]## ? 	46M4r   c                   2    e Zd Zd Zd Zd Zd Zd Zd ZdS )CUDAGraphWithStaticInputOutputc                     || _         t                      | _        d| _        d| _        d | _        d | _        d | _        d | _        d S )NF)	num_warmup_stepsr   graphhas_recordedhas_preserved_inputsargs_statickwargs_staticinputs_staticoutputs_staticselfrb   s     r   __init__z'CUDAGraphWithStaticInputOutput.__init__   sM     0[[
!$)!! ""r   c                    | j         s6d| _         || _        || _        t          | j        | j                  | _        dS t          ||          }t          | j        |          D ]\  }}|                    |d           dS )a  
        For the CUDA Graph, it is crucial that the buffer remains address-stable,
        meaning that the buffer addresses for any inputs to the CUDA Graph should not change.
        One solution to achieve this is to preserve all input tensors.

        This function attempts to recursively flatten the input arguments and keyword arguments
        to identify all tensors passed to the layer (though it may still miss some due to other implicit
        ways inputs can be passed to a layer). It then preserves references to these input tensors
        as `self.inputs_static` so that the buffer pointers can be reused later.

        When this method is called subsequently, it copies the values back to the preserved input tensors
        to ensure the buffers are reused.
        TN)re   rf   rg   rZ   rh   zipcopy_)rk   rX   rY   inputsx_staticr   s         r   preserve_or_copyz/CUDAGraphWithStaticInputOutput.preserve_or_copy   s     ( 
	((,D%#D!'D!> $"4" "D 34@@F"4#5v>> ( (!q$''''( (r   c                 ,   |                      ||           | j                                          || j        i | j        | _        | j                                         t          d           | j                                         d| _	        | j        S )NzF[CUDAGraph] Record-Replay Start (Graph is replayed for the first time)T)
rr   rc   capture_beginrf   rg   ri   capture_endr   replayrd   )rk   frX   rY   s       r   recordz%CUDAGraphWithStaticInputOutput.record   s    dF+++
  """a!1HT5GHH
   T	
 	
 	
 	
 ""r   c                     || _         d S r   )ri   )rk   ri   s     r   set_output_staticz0CUDAGraphWithStaticInputOutput.set_output_static   s    ,r   c                     | j         st          d          |                     ||           t          d           | j                                         | j        S )NzGraph should be recorded firstz[CUDAGraph] Replay Start)rd   RuntimeErrorrr   r   rc   rv   ri   )rk   rX   rY   s      r   rv   z%CUDAGraphWithStaticInputOutput.replay   s_      	A?@@@dF+++.///
""r   c                 h    t          j        d|            | j                            |           d S )Nzsave graph to )loggingr   rc   print_to_dot_files)rk   r(   s     r   savez#CUDAGraphWithStaticInputOutput.save   s7    ,d,,---
%%d+++++r   N)	__name__
__module____qualname__rl   rr   rx   rz   rv   r   r8   r   r   r`   r`      sn        # # #( ( (4# # #- - -# # #, , , , ,r   r`   c                       e Zd ZdZdZdZdZdS )CUDAGraphLayerStatusz3Enum to represent the status of a CUDA Graph Layer.r         N)r   r   r   __doc__WARMUPRECORD	CUDAGRAPHr8   r   r   r   r      s#        ==FFIIIr   r   c                   &    e Zd Zd Zd Zd Zd ZdS )CUDAGraphForwardBackwardc                 x    t          |          | _        t          |          | _        t          j        | _        d S r   )r`   forward_graphbackward_graphr   r   statusrj   s     r   rl   z!CUDAGraphForwardBackward.__init__   s2    ;<LMM<=MNN*1r   c                 (    t           j        | _        d S r   )r   r   r   rk   s    r   ru   z$CUDAGraphForwardBackward.capture_end   s    *4r   c                 ,    | j         t          j        k    S r   )r   r   r   r   s    r   is_record_stepz'CUDAGraphForwardBackward.is_record_step       {2999r   c                 ,    | j         t          j        k    S r   r   r   r   r   s    r   is_cuda_graph_stepz+CUDAGraphForwardBackward.is_cuda_graph_step       {2<<<r   N)r   r   r   rl   ru   r   r   r8   r   r   r   r      sP        2 2 2
5 5 5: : := = = = =r   r   c                   B    e Zd ZdZd Zd Zd Zd Zd Zd Z	d Z
d	 Zd
S )CUDAGraphContextz
    Manages the context for CUDA graph execution in layers. This includes handling
    the state of CUDA graph layers, managing forward and backward graphs, and
    tracking the execution steps.
    c                     || _         || _        d| _        t          j        | _        t                      | _        t                      | _        dS )z
        Initializes the CUDA graph context.
        :param layer: The layer to be used in the CUDA graph.
        :param num_warmup_steps: Number of warmup steps before recording starts.
        r   N)	layerrb   _stepr   r   r   r   
data_queuegraph_queue)rk   r   rb   s      r   rl   zCUDAGraphContext.__init__   sG     
 0 
*1  '' !77r   c                     t          | j                  dk    rt          | j                  S | j                                        S )Nr   )lenr   r   rb   popleftr   s    r   	get_graphzCUDAGraphContext.get_graph  s>    t  A%%+D,ABBB#++---r   c                 :    | j                             |           d S r   )r   r   )rk   gs     r   reuse_graphzCUDAGraphContext.reuse_graph  s    """""r   c                 :    | j                             |           d S r   )r   r   )rk   rX   s     r   	push_datazCUDAGraphContext.push_data  s    t$$$$$r   c                 4    | j                                         S r   )r   r   r   s    r   pop_datazCUDAGraphContext.pop_data  s    &&(((r   c                 l    | xj         dz  c_         | j         | j        k    rt          j        | _        d S d S )Nr   )r   rb   r   r   r   r   s    r   warmup_stepzCUDAGraphContext.warmup_step   s7    

a

:....8DKKK /.r   c                 ,    | j         t          j        k    S r   )r   r   r   r   s    r   is_warmup_stepzCUDAGraphContext.is_warmup_step%  r   r   c                 ,    | j         t          j        k    S r   r   r   s    r   r   z#CUDAGraphContext.is_cuda_graph_step(  r   r   N)r   r   r   r   rl   r   r   r   r   r   r   r   r8   r   r   r   r      s         # # #*. . .# # #% % %) ) )9 9 9
: : := = = = =r   r   c                 |   d\  }}t          | t          j                  r| |d         }}nYt          | t          t          f          r=t          | |          D ],\  }}t          |t          j                  r|j        s||}} n-t          |t          j                  rt          |t          j                  sJ ||fS )N)NNr   )r    r!   r"   rE   rF   rn   rL   )ysdysydyvdvs         r   select_y_with_gradr   ,  s    EAr"fm$$ CF2	Bu	&	& S\\ 	 	EAr!V]++ Q_ 22a''IJr6=,I,IIIIb5Lr   c                 0   | \  }}g }t          ||          D ]p\  }}|j        sO|j        -|                    t	          j        |j                             @|                    |j                   [|                    d            qt          |          S r   )rn   rL   r]   r   r!   zerosr#   rF   )rp   grad_inputsdetached_grad_inputs	args_gradr   
detached_xs         r   get_args_gradr   <  s    (.%K%I[*>?? # #: 		#&   j.>!?!?@@@@   1111T""""r   c                   >    e Zd ZdZed             Zed             ZdS )_CUDAGraphedLayerz
    A custom layer that integrates CUDA Graph recording and execution into PaddlePaddle's autograd system.
    It handles forward and backward operations differently based on the CUDA graph layer status.
    c                    |\  }}t          |          }t          |          }t          ||          }||f}t          |d                                           st          rkt          d           t          j                    5   j        |i |}ddd           n# 1 swxY w Y   	                    t          j        d||f           nډ                                }	|	                                r^t          dt          |	                      fd}
 |	j        j        |
g|R i |}	                    t          j        |	||f           nTt          dt          |	                       |	j        j        |i |}	                    t          j        |	d|f           t          d           |                                t          |d           t          |          S )	z
        Handles the forward pass of the layer. It operates differently based on the
        context's status: warmup, recording, or CUDA graph step.
        zForward inputz"[CUDAGraph] Forward Step (Default)Nz%[CUDAGraph] Forward Step (Record) id c                  x    t          j                    5   j        | i |cd d d            S # 1 swxY w Y   d S r   )r!   enable_gradr   )rX   rY   contexts     r   forwardz*_CUDAGraphedLayer.forward.<locals>.forwards  s    +-- > >,w}d=f==> > > > > > > > > > > > > > > > > >s   /33z$[CUDAGraph] Forward Step (Graph) id z[CUDAGraph] Forward Step EndzForward output)rK   rZ   r/   r   *debug_cudagraphedlayer_fallback_to_defaultr   r!   r   r   r   r   r   r   r   idr   rx   r   rv   r   save_for_backward)ctxr   	arg_tupler   rX   rY   r   rp   r   rc   r   s    `         r   r   z_CUDAGraphedLayer.forwardT  sm    !fd||<T6JJ34$o666""$$	9	 <===#%% 3 3!GM426223 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3:D&!LMMMM%%''E##%% OBuIIOOPPP> > > > > /E'.wHHHHHH!!)0%C    N2e99NNOOO.E'.???!!)3UD!D   	2333g&&&#$$$ayys   B  B$'B$c                    |                                  \  }|                                \  }}}}t          ||          \  }}t          ||fd           |t          j        k    rIt          d           |                    |           t          |          }	|	                                 n|t          j
        k    rt          dt          |                      d }
|j                            |
||           t          |          }	|j                            |	           |                                 |                    |           no|t          j        k    rPt          dt          |                      |j                            ||          }	|                    |           nt'          d          t          d           t          |	d           |	S )	z
        Handles the backward pass of the layer. Similar to forward, it handles
        backward based on the context's status: warmup, record, or CUDAGraph.
        zBackward inputz#[CUDAGraph] Backward Step (Default)z&[CUDAGraph] Backward Step (Record) id c                 0    |                      |           d S r   )backward)r   r   s     r   r   z,_CUDAGraphedLayer.backward.<locals>.backward  s    

2r   z%[CUDAGraph] Backward Step (Graph) id zUnknown cuda graph statusz[CUDAGraph] Backward Step EndzBackward output)saved_tensorr   r   r/   r   r   r   r   r   r   r   r   r   rx   rz   ru   r   r   rv   r|   )r   r   r   r   rc   rp   r   r   r   r   r   s              r   r   z_CUDAGraphedLayer.backward  s    %%''
&-&6&6&8&8#"2s++2B)***)000=>>> JJrNNN%f--I!!!!+222LELLMMM    ''!R888 &f--I 229===&&&&+555K5		KKLLL ,33Ar::I&&&&:;;;3444	,---r   N)r   r   r   r   staticmethodr   r   r8   r   r   r   r   N  sU         
 3 3 \3j 0 0 \0 0 0r   r   c                   P     e Zd ZdZddej        j        f fdZd Zd Z	d Z
 xZS )	CUDAGraphedLayera  
    CUDAGraphedLayer: A PaddlePaddle Layer to convert an eager mode model to utilize CUDA Graphs.

    CUDA Graphs provide a way to capture kernel-level operations of a model and play
    them back efficiently, allowing for potential speedups in repetitive computations,
    such as those during training iterations. This layer is a wrapper that enables
    the usage of CUDA Graphs with PaddlePaddle models.

    Overview:
    - The layer encapsulates another layer (the model to be converted).
    - During the first few (num_warmup_steps) iterations, the layer operates in
      eager mode without any CUDA Graphs.
    - After the warmup steps, the layer captures the forward and backward computations
      and replays them using CUDA Graphs in subsequent iterations.

    Usage:
        model = Model()
        graphed_model = CUDAGraphedLayer(model)

    Parameters:
    - layer (paddle.nn.Layer): The PaddlePaddle model/layer to be converted.
    - num_warmup_steps (int): The number of iterations before the CUDA Graph
      capture begins. Default is 3.

    Notes:
    - Restrictions:
        * CPU-GPU Synchronization: Operations that synchronize the CPU with the GPU, like device to host transfers, are not allowed.
        * CPU Work: Any operations on the CPU within the captured graph are not recorded.
        * Memory Address (Pointer) Consistency: Replays consistently read from and write to identical virtual memory addresses.
        * Dynamic Operations:
            - Control Flow: Dynamic control flows, especially those based on CPU data like if/else statements, are prohibited.
            - Tensor Shapes: Dynamic tensor shapes are not supported.

    - Allowed Operations:
        * CUDA RNG Operations: CUDA-based Random Number Generation operations are allowed.
    r   r   c                     t                                                       t          ||          | _        |                     dt          |          j         |           d S )NzGraphed )superrl   r   r   add_sublayertyper   )rk   r   rb   	__class__s      r   rl   zCUDAGraphedLayer.__init__  sY    '/?@@;T%[[%9;;UCCCCCr   c                 V    t          ||          }t          j        | j        ||fg|R  S r   )rZ   r   applyr   )rk   rX   rY   r   s       r   r   zCUDAGraphedLayer.forward  s>    3D&AA &L4.
+6
 
 
 	
r   c                 4    | j                                         S r   )r   r   r   s    r   r   zCUDAGraphedLayer.is_warmup_step  s    |**,,,r   c                 4    | j                                         S r   )r   r   r   s    r   r   z#CUDAGraphedLayer.is_cuda_graph_step  s    |..000r   )r   )r   r   r   r   r!   nnLayerrl   r   r   r   __classcell__)r   s   @r   r   r     s        # #JD Dfio D D D D D D

 
 
- - -1 1 1 1 1 1 1r   r   )r   TFTN)r/   ))r~   oscollectionsr   enumr   r!   paddle.baser   graphsr   boolintgetenvr   r   
get_loggerr   INFOr   r   r.   r/   r4   rO   rU   rZ   rK   r^   r`   r   r   r   r   r   autogradPyLayerr   r   r   r   r8   r   r   <module>r      s    				              " " " " " "       TC		?EEFF   .2TC		EsKKLL. . * 
	gl =
 
 

   
	( ( ( (6% % % %# # #    $   
5	4  G, G, G, G, G, G, G, G,V    4   = = = = = = = = 6= 6= 6= 6= 6= 6= 6= 6=r     $m m m m m/ m m m`61 61 61 61 61vy 61 61 61 61 61r   