diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py
index c50ea956daaa57..01d29fd2735a99 100644
--- a/python/paddle/autograd/backward_utils.py
+++ b/python/paddle/autograd/backward_utils.py
@@ -689,3 +689,67 @@ def get_split_op(value):
 @lru_cache
 def warning_once(message: str):
     logging.warning(message)
+
+
+def update_if_output_stopgradient(if_op, true_yield_op, false_yield_op):
+    """
+    Update the stop_gradient of if_op's results from the two yield ops.
+
+    Args:
+        if_op: the if op whose results' stop_gradient is updated.
+        true_yield_op: the last op of if_op's true block.
+        false_yield_op: the last op of if_op's false block.
+    """
+    if (
+        true_yield_op.name() != 'cf.yield'
+        or false_yield_op.name() != 'cf.yield'
+    ):
+        raise ValueError("param is not a cf.yield op")
+
+    # Check that both yield ops have the same number of operands_source
+    if len(true_yield_op.operands_source()) != len(
+        false_yield_op.operands_source()
+    ):
+        raise ValueError("Mismatched yield operands_source sizes")
+
+    # Check that if_op's results size matches the yield operands_source size
+    if len(if_op.results()) != len(true_yield_op.operands_source()):
+        raise ValueError(
+            "Mismatched if op results size with yield operands_source size"
+        )
+
+    # Update if_op's stop_gradient
+    for i in range(len(true_yield_op.operands_source())):
+        stop_grad1 = true_yield_op.operand_source(i).stop_gradient
+        stop_grad2 = false_yield_op.operand_source(i).stop_gradient
+
+        # Set to False if either branch's stop_gradient is False
+        if not stop_grad1 or not stop_grad2:
+            if_op.result(i).stop_gradient = False
+
+
+def update_while_output_stopgradient(while_op, yield_op):
+    """
+    Update the stop_gradient of while_op's results based on yield_op.
+
+    Args:
+        while_op: the while op whose results' stop_gradient is updated.
+        yield_op: the yield op at the end of the while loop's body block.
+    """
+    # Check that yield_op is indeed a yield operation
+    if yield_op.name() != 'cf.yield':
+        raise ValueError("yield_op is not a yield operation")
+
+    # yield_op has one extra operand: the leading loop condition
+    if len(while_op.results()) + 1 != len(yield_op.operands_source()):
+        raise ValueError(
+            f"Mismatched sizes: while op has {len(while_op.results())} results but yield op has {len(yield_op.operands_source())} operands_source"
+        )
+
+    # Update while_op's stop_gradient, skipping the loop condition operand
+    for i in range(1, len(yield_op.operands_source())):
+        stop_grad = yield_op.operand_source(i).stop_gradient
+
+        # Set to False if stop_gradient is False
+        if not stop_grad:
+            while_op.result(i - 1).stop_gradient = False
diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py
index 587096d7ef00a7..9bde3a05b3a89a 100644
--- a/python/paddle/autograd/ir_backward.py
+++ b/python/paddle/autograd/ir_backward.py
@@ -45,7 +45,9 @@
     return_map_value,
     return_map_value_list,
     some_in_set,
+    update_if_output_stopgradient,
     update_no_grad_set_by_stopgradient,
+    update_while_output_stopgradient,
     warning_once,
     while_prune_check,
 )
@@ -791,6 +793,11 @@ def append_yield(
                             input_tuple[1]
                         )
 
+                    update_if_output_stopgradient(
+                        grad_op,
+                        grad_op.as_if_op().true_block().ops[-1],
+                        grad_op.as_if_op().false_block().ops[-1],
+                    )
                     for input_tuple in inputs_used_by_other_op:
                         state.value_to_valuegrad[input_tuple[0]] = []
                 # update input_grad map
@@ -870,6 +877,10 @@ def append_yield(
                     sub_bwd_value_to_block_argument_map,
                     sub_control_flow_value_to_copyvalue_map,
                 )
+
+                update_while_output_stopgradient(
+                    grad_op, while_grad_block.ops[-1]
+                )
                 # update input_grad map
                 update_input_grad_map(op, input_grads, origin_inputs)
             elif op.name() == "pd_op.pylayer":