From ff5a06f43acc4c6783a25fa94b6f53bb09dc37eb Mon Sep 17 00:00:00 2001
From: wangruting
Date: Tue, 24 Dec 2024 11:09:52 +0800
Subject: [PATCH 1/5] generate_vjp optional output add check

---
 .../vjp_interface/generated/generated_vjp.cc.j2 | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/primitive/codegen/templates/vjp_interface/generated/generated_vjp.cc.j2 b/paddle/fluid/primitive/codegen/templates/vjp_interface/generated/generated_vjp.cc.j2
index 7882abcb5aff1d..806e7cf5479ae3 100644
--- a/paddle/fluid/primitive/codegen/templates/vjp_interface/generated/generated_vjp.cc.j2
+++ b/paddle/fluid/primitive/codegen/templates/vjp_interface/generated/generated_vjp.cc.j2
@@ -96,13 +96,17 @@ auto op_res = backend::{{api.name}}({{common.args(input_names, attr_
 {% for i in range(outputs|length) %}
 {% if outputs[i].typename=='Tensor' %}
 {% if outputs[i].optional %}
-vjp_res[{{i}}][0] = std::get<{{i}}>(op_res).get();
+if(std::get<{{i}}>(op_res)){
+  vjp_res[{{i}}][0] = std::get<{{i}}>(op_res).get();
+}
 {% else %}
 vjp_res[{{i}}][0] = std::get<{{i}}>(op_res);
 {% endif %}
 {% else %}
 {% if outputs[i].optional %}
-vjp_res[{{i}}] = std::get<{{i}}>(op_res).get();
+if(std::get<{{i}}>(op_res)){
+  vjp_res[{{i}}] = std::get<{{i}}>(op_res).get();
+}
 {% else %}
 vjp_res[{{i}}] = std::get<{{i}}>(op_res);
 {% endif %}
@@ -111,13 +115,17 @@ vjp_res[{{i}}] = std::get<{{i}}>(op_res);
 {% elif outputs|length == 1 %}
 {% if outputs[0].typename=='Tensor' %}
 {% if outputs[0].optional %}
-vjp_res[0][0] = op_res.get();
+if(op_res){
+  vjp_res[0][0] = op_res.get();
+}
 {% else %}
 vjp_res[0][0] = op_res;
 {% endif %}
 {% else %}
 {% if outputs[0].optional %}
-vjp_res[0] = op_res.get();
+if(op_res){
+  vjp_res[0] = op_res.get();
+}
 {% else %}
 vjp_res[0] = op_res;
 {% endif %}

From 76f679e2984e368a205448f4407e7c655bad2c12 Mon Sep 17 00:00:00 2001
From: wangruting
Date: Mon, 30 Dec 2024 14:44:29 +0800
Subject: [PATCH 2/5] add if while grad op stopgradient update func

---
 python/paddle/autograd/backward_utils.py | 69 ++++++++++++++++++++++++
 python/paddle/autograd/ir_backward.py    | 15 +++++-
 2 files changed, 82 insertions(+), 2 deletions(-)

diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py
index c50ea956daaa57..736743650c47e8 100644
--- a/python/paddle/autograd/backward_utils.py
+++ b/python/paddle/autograd/backward_utils.py
@@ -689,3 +689,72 @@ def get_split_op(value):
 @lru_cache
 def warning_once(message: str):
     logging.warning(message)
+
+
+def update_if_output_stopgradient(if_op, true_yield_op, false_yield_op):
+    """
+    Update if_op's stop_gradient based on true_yield_op and false_yield_op.
+
+    Args:
+        true_yield_op: the last op (yield) of if_op's true block.
+        false_yield_op: the last op (yield) of if_op's false block.
+        if_op: update its op_results()'s stop_gradient.
+ """ + if ( + true_yield_op.name() != 'cf.yield' + or false_yield_op.name() != 'cf.yield' + ): + raise ValueError("param isnot yield op") + + # Check if operands_source sizes match + if ( + true_yield_op.operands_source().size() + != false_yield_op.operands_source().size() + ): + raise ValueError("Mismatched yield operands_source sizes") + + # Check if op_results size matches operands_source + if if_op.op_results().size() != true_yield_op.operands_source().size(): + raise ValueError( + "Mismatched if op_results size with yield operands_source" + ) + + # Update if_op's stop_gradient + for i in range(true_yield_op.operands_source().size()): + stop_grad1 = true_yield_op.operand_source(i).stop_gradient + stop_grad2 = false_yield_op.operand_source(i).stop_gradient + + # Set to False if either stop_gradient is False + if not stop_grad1 or not stop_grad2: + if_op.op_result(i).stop_gradient = False + + +def update_while_output_stopgradient(while_op, yield_op): + """ + Update while_op's stop_gradient based on yield_op. + + Args: + yield_op: The yield operation associated with the while loop. + while_op: The while operation whose op_results()'s stop_gradient needs to be updated. + """ + # Check if yield_op is indeed a yield operation + if yield_op.name() != 'cf.yield': + raise ValueError("yield_op is not a yield operation") + + # Check if operands_source size of yield_op matches op_results size of while_op + if while_op.op_results().size() != yield_op.operands_source().size(): + raise ValueError( + "Mismatched while op_results size with yield operands_source" + ) + + # Update while_op's stop_gradient + for i in range(yield_op.operands_source().size()): + stop_grad = yield_op.operand_source(i).stop_gradient + + # Set to False if stop_gradient is False + if not stop_grad: + while_op.op_result(i).stop_gradient = False + + +# Example usage (assuming appropriate classes and methods) +# update_while_output_stopgradient(while_op, yield_op) diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index 587096d7ef00a7..57173d53f40d6b 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -45,7 +45,9 @@ return_map_value, return_map_value_list, some_in_set, + update_if_output_stopgradient, update_no_grad_set_by_stopgradient, + update_while_output_stopgradient, warning_once, while_prune_check, ) @@ -791,6 +793,11 @@ def append_yield( input_tuple[1] ) + update_if_output_stopgradient( + grad_op, + grad_op.true_block().ops[-1], + grad_op.false_block().ops[-1], + ) for input_tuple in inputs_used_by_other_op: state.value_to_valuegrad[input_tuple[0]] = [] # update input_grad map @@ -870,6 +877,10 @@ def append_yield( sub_bwd_value_to_block_argument_map, sub_control_flow_value_to_copyvalue_map, ) + + update_while_output_stopgradient( + grad_op, while_grad_block.ops[-1] + ) # update input_grad map update_input_grad_map(op, input_grads, origin_inputs) elif op.name() == "pd_op.pylayer": @@ -1227,7 +1238,7 @@ def calc_gradient( grad_outputs=grad_outputs, no_grad_set=ValueSet(no_grad_set), ) - + print(paddle.static.default_main_program()) inputgrad = [] for input in inputs: inputgrad.append( @@ -1402,7 +1413,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None): grad_outputs=[], no_grad_set=ValueSet(no_grad_set_), ) - + print(paddle.static.default_main_program()) input_inputs_grad = [] for input in parameter_list: input_inputs_grad.append( From 84c89b294111f552b5a5fd8b4f14c57bdcc75f0b Mon Sep 17 00:00:00 2001 From: wangruting 
Date: Mon, 30 Dec 2024 16:16:30 +0800
Subject: [PATCH 3/5] modify bug

---
 python/paddle/autograd/backward_utils.py | 21 ++++++++-------------
 python/paddle/autograd/ir_backward.py    |  4 ++--
 2 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py
index 736743650c47e8..c8188271a52d8f 100644
--- a/python/paddle/autograd/backward_utils.py
+++ b/python/paddle/autograd/backward_utils.py
@@ -707,26 +707,25 @@ def update_if_output_stopgradient(if_op, true_yield_op, false_yield_op):
         raise ValueError("param is not a yield op")
 
     # Check if operands_source sizes match
-    if (
-        true_yield_op.operands_source().size()
-        != false_yield_op.operands_source().size()
+    if len(true_yield_op.operands_source()) != len(
+        false_yield_op.operands_source()
     ):
         raise ValueError("Mismatched yield operands_source sizes")
 
     # Check if op_results size matches operands_source
-    if if_op.op_results().size() != true_yield_op.operands_source().size():
+    if len(if_op.results()) != len(true_yield_op.operands_source()):
         raise ValueError(
             "Mismatched if op_results size with yield operands_source"
         )
 
     # Update if_op's stop_gradient
-    for i in range(true_yield_op.operands_source().size()):
+    for i in range(len(true_yield_op.operands_source())):
         stop_grad1 = true_yield_op.operand_source(i).stop_gradient
         stop_grad2 = false_yield_op.operand_source(i).stop_gradient
 
         # Set to False if either stop_gradient is False
         if not stop_grad1 or not stop_grad2:
-            if_op.op_result(i).stop_gradient = False
+            if_op.result(i).stop_gradient = False
 
 
 def update_while_output_stopgradient(while_op, yield_op):
@@ -742,19 +741,15 @@ def update_while_output_stopgradient(while_op, yield_op):
         raise ValueError("yield_op is not a yield operation")
 
     # Check if operands_source size of yield_op matches op_results size of while_op
-    if while_op.op_results().size() != yield_op.operands_source().size():
+    if len(while_op.results()) != len(yield_op.operands_source()):
         raise ValueError(
             "Mismatched while op_results size with yield operands_source"
         )
 
     # Update while_op's stop_gradient
-    for i in range(yield_op.operands_source().size()):
+    for i in range(1, len(yield_op.operands_source())):
         stop_grad = yield_op.operand_source(i).stop_gradient
 
         # Set to False if stop_gradient is False
         if not stop_grad:
-            while_op.op_result(i).stop_gradient = False
-
-
-# Example usage (assuming appropriate classes and methods)
-# update_while_output_stopgradient(while_op, yield_op)
+            while_op.result(i - 1).stop_gradient = False
diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py
index 57173d53f40d6b..c4d276caa73dea 100644
--- a/python/paddle/autograd/ir_backward.py
+++ b/python/paddle/autograd/ir_backward.py
@@ -795,8 +795,8 @@ def append_yield(
 
             update_if_output_stopgradient(
                 grad_op,
-                grad_op.true_block().ops[-1],
-                grad_op.false_block().ops[-1],
+                grad_op.as_if_op().true_block().ops[-1],
+                grad_op.as_if_op().false_block().ops[-1],
             )
             for input_tuple in inputs_used_by_other_op:
                 state.value_to_valuegrad[input_tuple[0]] = []

From 3c57e8d5667ef039073b14baba516e1d5068de6a Mon Sep 17 00:00:00 2001
From: wangruting
Date: Tue, 31 Dec 2024 10:13:30 +0800
Subject: [PATCH 4/5] fix while op update

---
 python/paddle/autograd/backward_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py
index c8188271a52d8f..01d29fd2735a99 100644
--- a/python/paddle/autograd/backward_utils.py
+++ b/python/paddle/autograd/backward_utils.py
@@ -741,9 +741,9 @@ def update_while_output_stopgradient(while_op, yield_op):
         raise ValueError("yield_op is not a yield operation")
 
     # Check if operands_source size of yield_op matches op_results size of while_op
-    if len(while_op.results()) != len(yield_op.operands_source()):
+    if len(while_op.results()) + 1 != len(yield_op.operands_source()):
         raise ValueError(
-            "Mismatched while op_results size with yield operands_source"
+            f"Mismatched while op_results size {len(while_op.results()) + 1} with yield operands_source {len(yield_op.operands_source())}"
         )
 
     # Update while_op's stop_gradient

From cc43f3e79687ca3a4bd954588e6b0bc716a3751c Mon Sep 17 00:00:00 2001
From: wangruting
Date: Tue, 31 Dec 2024 14:10:30 +0800
Subject: [PATCH 5/5] delete print

---
 python/paddle/autograd/ir_backward.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py
index c4d276caa73dea..9bde3a05b3a89a 100644
--- a/python/paddle/autograd/ir_backward.py
+++ b/python/paddle/autograd/ir_backward.py
@@ -1238,7 +1238,7 @@ def calc_gradient(
         grad_outputs=grad_outputs,
         no_grad_set=ValueSet(no_grad_set),
     )
-    print(paddle.static.default_main_program())
+
     inputgrad = []
     for input in inputs:
         inputgrad.append(
@@ -1413,7 +1413,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
         grad_outputs=[],
         no_grad_set=ValueSet(no_grad_set_),
     )
-    print(paddle.static.default_main_program())
+
     input_inputs_grad = []
    for input in parameter_list:
         input_inputs_grad.append(
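
A minimal sketch of the stop_gradient propagation rule that patches 2-4 implement, for readers outside the Paddle codebase. It is not part of the patch series, and the Fake* classes below are hypothetical stand-ins for PIR values and cf.yield ops rather than real Paddle APIs. In the real code, update_if_output_stopgradient walks the cf.yield ops of both blocks of pd_op.if, and the while variant additionally skips yield operand 0 (the loop condition), pairing operand i with result i - 1.

# Illustrative sketch only -- hypothetical stand-ins, not Paddle's real classes.
class FakeValue:
    def __init__(self, stop_gradient=True):
        self.stop_gradient = stop_gradient


class FakeYieldOp:
    """Stand-in for a cf.yield op: holds the values a block yields."""

    def __init__(self, operands):
        self._operands = operands

    def operands_source(self):
        return self._operands

    def operand_source(self, i):
        return self._operands[i]


class FakeIfOp:
    """Stand-in for pd_op.if: one result per yielded value."""

    def __init__(self, num_results):
        self._results = [FakeValue() for _ in range(num_results)]

    def results(self):
        return self._results

    def result(self, i):
        return self._results[i]


# The rule: the if op's i-th result needs a gradient when either branch
# yields a value at position i whose stop_gradient is False.
true_yield = FakeYieldOp([FakeValue(stop_gradient=False), FakeValue()])
false_yield = FakeYieldOp([FakeValue(), FakeValue()])
if_op = FakeIfOp(2)

for i in range(len(true_yield.operands_source())):
    if (
        not true_yield.operand_source(i).stop_gradient
        or not false_yield.operand_source(i).stop_gradient
    ):
        if_op.result(i).stop_gradient = False

print([v.stop_gradient for v in if_op.results()])  # -> [False, True]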