
    ϑiY                     j    S SK r S SKrS SKJrJrJr  S SKJr  S SKJ	r	J
r
Jr  S SKJr   " S S\5      rg)    N)core	frameworkunique_name)append_backward)Variablein_dygraph_modeprogram_guard)	Optimizerc                       \ rS rSrSrS rS rS r\R                  S 5       r
S rS rS	 rS
 rS rS rS rS rS rS rS rS rS rS rS rSS jr    SS jrS r SS jrSrg)RecomputeOptimizer   a  
    :api_attr: Static Graph

Recompute Optimizer Wrapper

Normally, a training step contains three sub-steps: first, run forward
Operators to calculate the loss; second, run backward Operators to
calculate gradient of the parameters; third, apply optimization method
to update the value of the parameters.

In the forward computation process, all variables that are needed by
backward computation process will be kept in memory, which occupy a great
amount of memory when the network becomes very deep.

Recompute split the network to k segments. In each segment, It will
recompute the forward Operators, before running backward operators. It is
very helpful for saving memory.

The Variables that separate a network to segments are called as checkpoints,
and users should set it manually. The usage is very simple:

Args:
    optimizer (Optimizer): The optimizer that is applied to parameters.

Examples:
    .. code-block:: python

        >>> import paddle
        >>> import numpy as np

        >>> paddle.enable_static()

        >>> def gen_data():
        ...     return {"x": np.random.random(size=(32, 32)).astype('float32'),
        ...     "y": np.random.randint(2, size=(32, 1)).astype('int64')}
        >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2):
        ...     print(input_x)
        ...     fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim)
        ...     prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax')
        ...     cost = paddle.nn.functional.cross_entropy(
        ...         input=prediction, label=input_y,
        ...         reduction='none', use_softmax=False
        ...     )
        ...     sum_cost = paddle.mean(cost)
        ...     return sum_cost, fc_1, prediction
        >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32')
        >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64')
        >>> cost, fc_1, pred = mlp(input_x, input_y)

        >>> sgd = paddle.optimizer.Adam(learning_rate=0.01)
        >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd)
        >>> sgd._set_checkpoints([fc_1, pred])
        >>> sgd.minimize(cost)

        >>> print("Finished optimize")
        Finished optimize
        >>> place = paddle.CPUPlace()
        >>> exe = paddle.static.Executor(place)
        >>> exe.run(paddle.static.default_startup_program())
        >>> step = 10

        >>> for i in range(step):
        ...     cost_val = exe.run(feed=gen_data(),
        ...             program=paddle.static.default_main_program(),
        ...             fetch_list=[cost.name])
        ...     print("step=%d cost=%f" % (i, cost_val[0]))
        var x : DENSE_TENSOR.shape(-1, 32).dtype(float32).stop_gradient(True)
        Finished optimize
        step=0 cost=0.737203
        step=1 cost=1.308077
        step=2 cost=0.768422
        step=3 cost=1.239475
        step=4 cost=0.882643
        step=5 cost=0.738027
        step=6 cost=0.819374
        step=7 cost=0.818534
        step=8 cost=0.753692
        step=9 cost=0.787448

c                     [        5       (       a  [        S5      eXl        S U l        U R                  R                  U l        U R                  R
                  U l        SU l        g )Nz-In dygraph, don't support RecomputeOptimizer.F)r   	Exception
_optimizer_checkpoints_learning_rate_learning_rate_mapenable_offload)self	optimizers     c/var/www/html/banglarbhumi/venv/lib/python3.13/site-packages/paddle/incubate/optimizer/recompute.py__init__RecomputeOptimizer.__init__j   sP    KLL# "oo<<"&//"D"D#    c                     [        U[        5      (       d   S5       eU H%  n[        U[        [        45      (       a  M    S5       e   Xl        g)z:
Args:
    checkpoints (list): List of Variable or string
z=_checkpoints should be a list of Variable or a list of StringN)
isinstancelistr   strr   )r   checkpointsckpts      r   _set_checkpoints#RecomputeOptimizer._set_checkpointss   sV    
 +t,, 	
K	
,  DdXsO44 O4   (r   c                     SU l         g )NT)r   r   s    r   _enable_offload"RecomputeOptimizer._enable_offload   s
    "r   c                     [        S5      e)a  
    :api_attr: Static Graph

load function is not supported by Recompute Optimizer for now.
:return: None

Args:
    state_dict: the dict load by load_persistable method

Examples:
    .. code-block:: python

        >>> import paddle

        >>> paddle.enable_static()
        >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2):
        ...     fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim)
        ...     prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax')
        ...     cost = paddle.nn.functional.cross_entropy(
        ...         input=prediction, label=input_y,
        ...         reduction='none', use_softmax=False
        ...     )
        ...     sum_cost = paddle.mean(cost)
        ...     return sum_cost, fc_1, prediction

        >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32')
        >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64')
        >>> cost, fc_1, pred = mlp(input_x, input_y)
        >>> print("Finished FF")
        Finished FF

        >>> sgd = paddle.optimizer.Adam(learning_rate=0.01)
        >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd)
        >>> sgd._set_checkpoints([fc_1, pred])
        >>> try:
        ...     state_dict = {}
        ...     sgd.load(state_dict)
        >>> except NotImplementedError as e:
        ...     print(e)
        load function is not supported by Recompute Optimizer for now
z=load function is not supported by Recompute Optimizer for now)NotImplementedError)r   
state_dicts     r   loadRecomputeOptimizer.load   s    V "K
 	
r   c                 4    U R                   R                  US9$ )a  
call apply_gradients function of self._optimizer.

Args:
    params_grads (list): list of (param, grad) pair to do optimization.

Returns:
    list: A list of operators appended to the current program.

Examples:
    .. code-block:: python

        >>> import paddle
        >>> import paddle.base.framework as framework

        >>> paddle.enable_static()

        >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2):
        ...     fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim)
        ...     prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax')
        ...     cost = paddle.nn.functional.cross_entropy(
        ...         input=prediction, label=input_y,
        ...         reduction='none', use_softmax=False
        ...     )
        ...     sum_cost = paddle.mean(cost)
        ...     return sum_cost, fc_1, prediction

        >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32')
        >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64')
        >>> cost, fc_1, pred = mlp(input_x, input_y)
        >>> print("Finished FF")
        Finished FF

        >>> sgd = paddle.optimizer.Adam(learning_rate=0.01)
        >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd)
        >>> sgd._set_checkpoints([fc_1, pred])
        >>> params_grads = sgd.backward(
        ...     cost,
        ...     startup_program=None,
        ...     parameter_list=None,
        ...     no_grad_set=None)

        >>> program = cost.block.program
        >>> with framework.program_guard(program, None):
        ...     optimize_ops = sgd.apply_gradients(params_grads)

        >>> print("Finished apply gradients")
        Finished apply gradients
)params_grads)r   apply_gradients)r   r-   s     r   r.   "RecomputeOptimizer.apply_gradients   s    f ..L.IIr   c                    [         R                  " US-   5      n[         R                  " US-   5      nU R                  R                  5       R	                  UU R
                  U R                  R                  5       R                  U5      R                  SSS9nU R                  R                  5       R	                  UU R
                  U R                  R                  5       R                  U5      R                  SSS9nX#4$ )Nz@Pinnedz@FetchFTnameshapedtypepersistablestop_gradient)r   generate_main_programglobal_block
create_varcheckpoint_shapevarr4   )r   varnamepinned_var_namefetched_var_name
pinned_var	fetch_vars         r   _create_varsRecomputeOptimizer._create_vars   s    %..w/BC&//(0BC''446AA ''$$11377@FF B 

 &&335@@!''$$11377@FF A 
	 00r   c                    SnUR                  5       nU R                  R                  5       n[        R                  R                  5       nU H  nU R                  R                  5       R                  U5      nUR                  UU R                  U R                  R                  5       R                  UR                  5      R                  SSS9nUR                  SSU0SUR                  SUR                  S	S
SSXR0S9  M     g)a   
add fill_constant_ops to the end of the prog

we should fill the pinned vars before running the main_prog
to instantiate their tensor hold_, which could tell us whether
the host memory could hold all the checkpoints from all the
GPU devices in this node.
r   FTr1   fill_constantOutr3   r4   valueg        
place_type   )typeoutputsattrsN)r9   checkpoint_name2pinned_namevaluesr   op_proto_and_checker_makerkOpRoleAttrNamer8   r<   r:   r;   r2   r4   	append_opr3   )	r   startup_programop_roleblockfill_constant_varsOP_ROLE_KEYr=   r<   r@   s	            r   _append_fill_constant_ops,RecomputeOptimizer._append_fill_constant_ops   s     ,,.!==DDF55EEG)G$$11377@C))++((557;;CHHEKK!" * J OO$(SYYSYYS !  
 *r   c           
      :   [         R                  R                  5       nU R                  R	                  USSU R
                  R                  5       R                  U5      /0SU R
                  R                  5       R                  U5      /0S[        U5      Xd0S9  g )NmemcpyXrF   dst_place_type)rJ   inputsrK   rL   )	r   rO   rP   rT   _insert_op_without_syncr8   r9   r<   int)r   
insert_idxsrc_varnamedst_varnamerS   r\   rV   s          r   _insert_async_memcpy_op*RecomputeOptimizer._insert_async_memcpy_op"  s     55EEG

**$,,99;??LMN**779==kJK $S%8+O 	+ 	
r   c                     X R                   ;   d   SU S35       eU R                   U   nU R                  U   nU R                  XUSS5        g )NzTry to fetch z/ from Pinned Memory, but it is NOT a checkpoint   )rM   checkpoint_name2fetch_namerc   )r   idxr=   pinned_varnamefetch_varnames        r   _insert_fetch_op#RecomputeOptimizer._insert_fetch_op0  s^    ::: 	
G9$ST	
: 99'B77@$$S-ANr   c                 ~    X R                   ;   d   SU S35       eU R                   U   nU R                  XUSS5        g )NzTry to offload z- to Pinned Memory, but it is NOT a checkpointr   rI   )rM   rc   )r   rh   r=   ri   s       r   _insert_offload_op%RecomputeOptimizer._insert_offload_op9  sN    ::: 	
gY&ST	
: 99'B$$S>1aHr   c                     g N )r   op_idxcheckpoint_names      r   _insert_sync_op"RecomputeOptimizer._insert_sync_op@      r   c                     [        U R                  5      S:  d   S5       eU R                  R                  S5      n[        R                  " SU S35        SU4U R
                  U'   U$ )Nr   z#Could NOT found checkpoint to fetchzRecord fetch []fetch)lenun_fetch_checkpoint_namespoploggingdebugidx2insertionsr   rh   rt   s      r   _record_fetch_op#RecomputeOptimizer._record_fetch_opD  sn    4112Q6 	
1	
6 88<<R@&7q9:$+_#=C r   c                     U R                   R                  S5      nX#:X  d   SU SU S35       e[        R                  " SU S35        SU4U R                  U'   g )Nr   zexpected to offload [z] but got [rz   zRecord offload [offload)un_offload_checkpoint_namesr~   r   r   r   )r   rh   rt   expected_checkpoint_names       r   _record_offload_op%RecomputeOptimizer._record_offload_opN  sp    #'#C#C#G#G#J : 	
#$<#=[HYYZ[	
: 	((9;<$-#?C r   c                     X R                   ;  d   SU S35       eU R                   R                  U5        [        R                  " SU S35        SU4U R                  U'   g )NzTry to sync the checkpoint [z] twicezRecord offload sync [rz   sync)synced_checkpointsaddr   r   r   r   s      r   _record_sync_op"RecomputeOptimizer._record_sync_opV  sj    &=&== 	
*?*;7C	
= 	##O4-o->a@A$*O#<C r   c                    0 U l         U R                  S S  U l        U R                  R                  S5        U R                  S S  n0 U l        U R                   H  nSU R                  U'   M     [        U R                  R                  5      U l        [        U R                  R                  5       H5  u  p4[        UR                  R                  S5      5      S:X  d  M/  X0l          O   U R                  [        U R                  R                  5      :  d   S5       eU R                  U R                  5      nS n[        U R                  R                  U R                  S  5       H  u  ptU R                  U-   nUR                  R                  5       nU H  n	X;   d  M
  XR                  ;  a  U R                  U	   S:X  a%  Un
XR                  S   :w  a  U R                  U5      nW
U	:X  d   SU
 SU	 S35       eU R                  R                  U   R                  U	U R                   U	   5        U R                  U	==   S-  ss'   M  [#        S	U	 S
35      e   M     [        U R                  5      S:X  d   U R                   S35       eg )Nry   r   rS   rf   z#Could NOT found backward op in progz&Current recompute segment should use [z] BUT got [rz   zuse checkpoint [z] before fetch in BW# checkpoints have NOT been Recorded)r   sorted_checkpoint_namesr}   r~   checkpoint_usage_countr|   rT   opsbw_start_op_idx	enumerater_   descattrr   input_arg_names_rename_inputrg   
ValueError)r   need_fetch_checkpoint_namesrt   rh   opfetched_checkpoint_varnamelast_last_fetch_checkpointi
input_vars	input_varsecond_to_last_fetch_checkpoints              r   _parse_backward"RecomputeOptimizer._parse_backward^  sv    )-)E)Ea)H&&&**2.&*&D&DQ&G#&(##==O;<D''8  >  #4::>>2 0GC277<<	*+q0'*$ 1
 ##c$**..&99 	
1	
9
 &*%:%:4;O;O%P"%)"tzz~~d.B.B.DEFEA&&*C002J'	; (F(FF66yAQF !; <  ),H,H,KK$($9$9#$> !;
  ?)K DEdDeepqzp{{|}K 

s+99% ;;IF 33I>!C>(.yk9MN 5 (	 GD 4112a7 	
--..QR	
7r   c                    [        U R                  5      S:X  a  g [        U R                  R                  5      n[	        [        U R                  U5      5       H  nX R                  ;   d  M  U R                  U   u  p4US:X  a:  U R                  X$5        [        R                  " SU S35        U R                  U	 Me  US:X  d  Mm  U R                  X$5        [        R                  " SU S35        M     U R                  R                  5         [        U R                  5      S:X  d5   U R                  R                  5        Vs/ s H  oUS   PM	     sn S35       eg s  snf )	Nr   r{   Insert [z] fetch op.r   zSync [rf   z checkpoints left un-Fetched)r|   r   rT   r   reversedranger   rk   r   r   ru   _sync_with_cpprN   )r   total_oprs   	operationrt   eles         r   _update_backward#RecomputeOptimizer._update_backward  s/   t""#q(tzz~~&uT%9%98DEF,,,-1-@-@-H*	'))&BMMH_,=["IJ++F3&(((AMMF?*;;"GH F 	

!!#4&&'1, 	
"&"5"5"<"<">?">3A">?@@\]	
,?s   Ec           
         0 U l         U R                  S S  U l        U R                  R                  S5      nU R                  S S  n0 U l        U R                   H  nSSS.U R                  U'   M     [        5       U l        [        U R                  R                  5      U l
        [        U R                  R                  5       H5  u  pE[        UR                  R                  S5      5      S:X  d  M/  X@l
          O   U R                  [        U R                  R                  5      :  d   S5       eS n[        U R                  R                  U R                  U R                   5       GH  u  puU R                  U-   nUR                  R!                  5       nUR                  R#                  5       n	U GHv  n
X;   a  [        U5      S:X  d   SU
 SU S	35       eXR                  ;   a{  Ub`  U R                  U   S
   S:X  a  U R%                  XF5        O8U R                  U   S   nUS:  d   SU S35       eU R%                  US-   U5        U R'                  US-   U
5        U
nO[)        SU
 S	35      eX:X  d  M  [        U5      S:X  d   SU
 SU S	35       eUU R                  S   :X  d   SU SU R                  S    SU S	35       eU R                  U   S   S:X  a  U R%                  XF5        GM>  U R                  U   S   nUS:  d   SU S35       eU R%                  US-   U5        GMy     U	 HO  nX;   d  M
  XR                  ;  d   SU S35       eU R                  U   S
==   S-  ss'   X@R                  U   S'   MQ     GM     [        U R                  5      S:X  d   U R*                   S35       e[        U R                  5      [        U5      :X  d)   [        U5      [        U R                  5      -
   S35       eg )Nry   r   )countrh   rS   z"Could NOT found Forward op in progrf   z;checkpoint should be the only Output of a certain op, but [z] is from [rz   r   rh   zlast_usage_idx of checkpoint [z] should large than 0z4There should be just ONE op that output checkpoint [z$the last offload checkpoint before [z] is suppose to be [z], but got [zcheckpoint [z] used after syncr   )r   r   r   r~   checkpoint_usage_count_and_idxsetr   r|   rT   r   fw_start_op_idxr   r_   r   r   r   output_arg_namesr   r   r   r   r}   )r   last_checkpointneed_offload_checkpoint_namesrt   rh   r   last_offload_checkpointr   output_varsr   
output_varlast_usage_idxr   s                r   _parse_forward!RecomputeOptimizer._parse_forward  s    +/+G+G+J(::>>rB(,(H(H(K%.0+#??ODD//@  @
 #&%"4::>>2 0GC277<<	*+q0'*$ 1
 ##c$**..&99 	
0	
9 #'JJNN4//$2F2FG
EA &&*C''224K002J)
>{+q0 UV`Uaalmolppqr0 "%E%EE2> $ C C$;!"")!+ $%!%
 !% 4 4$'!"
 %)$G$G(?%&&+%- !/
 (6'9 !"&DE\D]]r$s!"'9 !% 4 4$2Q$68O!" //aD2</(RS]R^^_`  0{+q0 UV`Uaalmolppqr0 077;< ?>OOcdh  eA  eA  BD  eE  dF  FR  Sj  Rk  kl  m	< ;;3! 
 ,,SJ)-)L)L3**!  .1 <=T<UUjk1 ,,*Q.0G *F (	=$,C,CC &yk1BCC 77	B7KqPKLO77	B5I (U
d 43349 	
--..QR	
9 4**+s)0
 
 	
 01C8O8O4PPQQtu	
 
r   c                    [        U R                  5      S:X  a  g [        [        U R                  U R
                  5      5       H  nXR                  ;   d  M  U R                  U   u  p#US:X  a:  U R                  X5        [        R                  " SU S35        U R                  U	 Me  US:X  d  Mm  U R                  X5        [        R                  " SU S35        U R                  U	 M     U R                  R                  5         [        U R                  5      S:X  d5   U R                  R                  5        Vs/ s H  oDS   PM	     sn S35       eg s  snf )	Nr   r   r   z] offload op.r   z] offload_sync op.rf   z checkpoints left un-Offloaded)r|   r   r   r   r   r   rn   r   r   ru   rT   r   rN   )r   rs   r   rt   r   s        r   _update_forward"RecomputeOptimizer._update_forward"  s;   t""#q($&&(<(<=
F ,,,-1-@-@-H*		)++FDMMH_,=]"KL++F3&(((AMM"?"33EF ++F3
  	

!!#4&&'1, 	
"&"5"5"<"<">?">3A">?@@^_	
,?s   =Ec                     g rq   rr   r$   s    r   _check_offload_fetch'RecomputeOptimizer._check_offload_fetch:  rw   r   Nc                 ,   UR                   R                  U l        UR                   U l         Uc  [        R                  R                  5       n[        U R                  U5         [        U R                  5      S:  d   SU R                   S35       e[        S U R                   5       5      (       d   SU R                   S35       e0 U l
        0 U l        U R                   H4  nU R                  U5      u  pEUU R                  U'   UU R                  U'   M6     U R                  U5        U R                  5         U R!                  5         U R#                  5         U R%                  5         U R'                  5         SSS5        g! , (       d  f       g= f)z
core steps for recompute offload
1. create pinned vars and temp vars
2. parse & update Forward pass: offload, sync
3. parse & update Backward pass: rename, fetch, sync
4. verify the correctness
Nr   zcheckpoints shape z2 should be an non empty list like: [12, 512, 1024]c              3   *   #    U  H	  oS :  v   M     g7f)r   Nrr   ).0r   s     r   	<genexpr>.RecomputeOptimizer._offload.<locals>.<genexpr>O  s     @*?3Qw*?s   zall ele in checkpoints shape z- should be a determined integer larger than 0)rT   programr8   paddlestaticdefault_startup_programr	   r|   r;   allrM   rg   r   rB   rW   r   r   r   r   r   )r   lossrR   checkpoint_varnamer>   fetch_var_names         r   _offloadRecomputeOptimizer._offload>  sx    "ZZ//ZZ
"$mmCCEO4--?t,,-1 $T%:%:$;;mn1 @$*?*?@@@ /0E0E/FFst@ 02D,.0D+&*&B&B"262C2C&3/ $ 001CD # //0BC 'C **?;   "!!#!  "%%'9 @??s   $DF
Fc                 j   U R                   c   S5       e[        5       (       a  [        S5      eUR                  U l        UR
                  R                  n[        Xb5         / nU R                    HU  n[        U[        5      (       a  UR                  U5        M+  UR                  UR
                  R                  U5      5        MW     [        U5      S:  a  [        UUUUS9u  pO[        UUUUS9n	SSS5        U R                  (       a  W
U l        U R!                  XS9  W	$ ! , (       d  f       N7= f)a&  
call append_backward with checkpoints.

Args:
    loss (Variable): loss variable to run optimizations.
    startup_program (Program): startup_program for initializing parameters
        in `parameter_list`.
    parameter_list (list): list of Variables or Variable.names to update.
    no_grad_set (set|None): set of Variables or Variables.names should be ignored.
    callbacks (list|None): list of callables to run when appending backward
        operator for one parameter.
    checkpoints (list): list of Variables as checkpoints

Examples:
    .. code-block:: python

        >>> import paddle

        >>> paddle.enable_static()

        >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2):
        ...     fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim)
        ...     prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax')
        ...     cost = paddle.nn.functional.cross_entropy(
        ...         input=prediction, label=input_y,
        ...         reduction='none', use_softmax=False
        ...     )
        ...     sum_cost = paddle.mean(cost)
        ...     return sum_cost, fc_1, prediction

        >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32')
        >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64')
        >>> cost, fc_1, pred = mlp(input_x, input_y)
        >>> print("Finished FF")
        Finished FF

        >>> sgd = paddle.optimizer.Adam(learning_rate=0.01)
        >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd)
        >>> sgd._set_checkpoints([fc_1, pred])
        >>> params_grads = sgd.backward(
        ...     cost,
        ...     startup_program=None,
        ...     parameter_list=None,
        ...     no_grad_set=None)
        >>> print("Finished backward")
        Finished backward
N&You should call _set_checkpoints first*DyGraph current does not support recomputer   )r   )rR   )r   r   r(   r4   _dtyperT   r   r	   r   r   appendr<   r|   r   r   r   r   )r   r   rR   parameter_listno_grad_set	callbacksr   checkpoint_varsr    r-   r   s              r   backwardRecomputeOptimizer.backwardi  s#   n   , 	
4	
, %<  jj**$$74 O))dH--#**40#**4::>>$+?@	 * ?#a'8G" /	955  /" /	 # 50 +BD(MM$M@9 54s   !BD$$
D2c                     [        U R                  S5      (       a  U R                  R                  OU R                  R                  nU" XUS9$ )a  
call the apply_optimize function of self._optimizer
Args:
    loss (Variable): loss variable to run optimizations.
    startup_program (Program): startup_program for initializing parameters
        in `parameter_list`.
    params_grads (list): list of (param, grad) pair to do optimization.
Examples:
    .. code-block:: python

        >>> import paddle

        >>> paddle.enable_static()

        >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2):
        ...     fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim)
        ...     prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax')
        ...     cost = paddle.nn.functional.cross_entropy(
        ...         input=prediction, label=input_y,
        ...         reduction='none', use_softmax=False
        ...     )
        ...     sum_cost = paddle.mean(cost)
        ...     return sum_cost, fc_1, prediction

        >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32')
        >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64')
        >>> cost, fc_1, pred = mlp(input_x, input_y)
        >>> print("Finished FF")
        Finished FF

        >>> sgd = paddle.optimizer.Adam(learning_rate=0.01)
        >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd)
        >>> sgd._set_checkpoints([fc_1, pred])
        >>> params_grads = sgd.backward(
        ...     cost,
        ...     startup_program=None,
        ...     parameter_list=None,
        ...     no_grad_set=None)

        >>> optimize_ops = sgd.apply_optimize(
        ...     cost, startup_program=None, params_grads=params_grads)

        >>> print("Finished apply_optimize")
        Finished apply_optimize
apply_optimizerR   r-   )hasattrr   r   _apply_optimize)r   r   rR   r-   funcs        r   r   !RecomputeOptimizer.apply_optimize  sL    b t(899 OO**00 	
 
 	
r   c                     [        U[        5      (       d   S5       eU R                  c   S5       e[        5       (       a  [	        S5      eU R                  UUUUS9nU R                  XUS9nXe4$ )NzThe loss should be an Variable.r   r   )rR   r   r   r   )r   r   r   r   r(   r   r   )r   r   rR   r   r   r-   optimize_opss          r   minimizeRecomputeOptimizer.minimize  s     $))L+LL)  , 	
4	
, %<  }}+)#	 % 
 ** + 
 ))r   )r   r   r   r   r8   r   rT   r   rg   rM   r   r   r   r   r   r   r   r}   r   rq   )NNNN)NNN)__name__
__module____qualname____firstlineno____doc__r   r!   r%   r   deprecate_stat_dictr*   r.   rB   rW   rc   rk   rn   ru   r   r   r   r   r   r   r   r   r   r   r   r   __static_attributes__rr   r   r   r   r      s    Ob$(# "",
 #,
\3Jj1,!F
OI@=<
|
&q
f
0)(\ ^@6
r LP*r   r   )r   r   paddle.baser   r   r   paddle.base.backwardr   paddle.base.frameworkr   r   r	   paddle.optimizerr
   r   rr   r   r   <module>r      s-      4 4 0 J J &* *r   