
【pir】Update if_grad and while_grad ops' stop_gradient by yield_op #70545

Merged
64 changes: 64 additions & 0 deletions python/paddle/autograd/backward_utils.py
@@ -689,3 +689,67 @@ def get_split_op(value):
@lru_cache
def warning_once(message: str):
    logging.warning(message)


def update_if_output_stopgradient(if_op, true_yield_op, false_yield_op):
    """
    Update if_op's results' stop_gradient flags based on true_yield_op and false_yield_op.

    Args:
        if_op: the if op whose results' stop_gradient flags are updated.
        true_yield_op: the last op of if_op's true block.
        false_yield_op: the last op of if_op's false block.
    """
    if (
        true_yield_op.name() != 'cf.yield'
        or false_yield_op.name() != 'cf.yield'
    ):
        raise ValueError("true_yield_op and false_yield_op must be yield ops")

    # Check that both yield ops have the same number of operands
    if len(true_yield_op.operands_source()) != len(
        false_yield_op.operands_source()
    ):
        raise ValueError("Mismatched yield operands_source sizes")

    # Check that if_op's result count matches the yield operand count
    if len(if_op.results()) != len(true_yield_op.operands_source()):
        raise ValueError(
            "Mismatched if op results size with yield operands_source size"
        )

    # Update if_op's results' stop_gradient flags
    for i in range(len(true_yield_op.operands_source())):
        stop_grad1 = true_yield_op.operand_source(i).stop_gradient
        stop_grad2 = false_yield_op.operand_source(i).stop_gradient

        # A result needs a gradient if either branch's yielded value does
        if not stop_grad1 or not stop_grad2:
            if_op.result(i).stop_gradient = False
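
For intuition, the merge rule above is a logical AND over the two branches' flags: a result keeps stop_gradient=True only when both yielded values have it. A minimal self-contained sketch with a hypothetical stand-in class (FakeValue is not part of the pir API):

class FakeValue:
    """Hypothetical stand-in for a pir Value; only stop_gradient matters here."""

    def __init__(self, stop_gradient):
        self.stop_gradient = stop_gradient

# One flag per yielded value, for the true and false branches respectively.
true_vals = [FakeValue(True), FakeValue(False), FakeValue(True)]
false_vals = [FakeValue(True), FakeValue(True), FakeValue(False)]

# A result keeps stop_gradient=True only when both branches agree.
merged = [
    t.stop_gradient and f.stop_gradient
    for t, f in zip(true_vals, false_vals)
]
assert merged == [True, False, False]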


def update_while_output_stopgradient(while_op, yield_op):
    """
    Update while_op's results' stop_gradient flags based on yield_op.

    Args:
        while_op: the while op whose results' stop_gradient flags are updated.
        yield_op: the yield op of the while loop's body block.
    """
    # Check that yield_op is indeed a yield op
    if yield_op.name() != 'cf.yield':
        raise ValueError("yield_op is not a yield operation")

    # The yield op carries one extra operand ahead of the values that
    # correspond to while_op's results
    if len(while_op.results()) + 1 != len(yield_op.operands_source()):
        raise ValueError(
            f"Mismatched sizes: while op has {len(while_op.results())} results "
            f"but yield op has {len(yield_op.operands_source())} operands "
            f"(expected results + 1)."
        )

    # Update while_op's results' stop_gradient flags, skipping operand 0
    for i in range(1, len(yield_op.operands_source())):
        stop_grad = yield_op.operand_source(i).stop_gradient

        # A result needs a gradient if the yielded value does
        if not stop_grad:
            while_op.result(i - 1).stop_gradient = False
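
The +1/-1 offsets above reflect that the yield op's first operand (operand 0) is the loop condition, which has no matching while_op result; operand i pairs with result i - 1. A small sketch of the same indexing with a hypothetical stand-in class (not the pir API):

class FakeValue:
    """Hypothetical stand-in for a pir Value; only stop_gradient matters here."""

    def __init__(self, stop_gradient):
        self.stop_gradient = stop_gradient

# Operand 0 is the loop condition; operands 1..n map to results 0..n-1.
yield_operands = [FakeValue(True), FakeValue(False), FakeValue(True)]
result_flags = [True, True]  # initial stop_gradient flags of while_op's results

for i in range(1, len(yield_operands)):
    if not yield_operands[i].stop_gradient:
        result_flags[i - 1] = False

assert result_flags == [False, True]
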
11 changes: 11 additions & 0 deletions python/paddle/autograd/ir_backward.py
@@ -45,7 +45,9 @@
return_map_value,
return_map_value_list,
some_in_set,
update_if_output_stopgradient,
update_no_grad_set_by_stopgradient,
update_while_output_stopgradient,
warning_once,
while_prune_check,
)
@@ -791,6 +793,11 @@ def append_yield(
input_tuple[1]
)

update_if_output_stopgradient(
grad_op,
grad_op.as_if_op().true_block().ops[-1],
grad_op.as_if_op().false_block().ops[-1],
)
for input_tuple in inputs_used_by_other_op:
state.value_to_valuegrad[input_tuple[0]] = []
# update input_grad map
@@ -870,6 +877,10 @@ def append_yield(
sub_bwd_value_to_block_argument_map,
sub_control_flow_value_to_copyvalue_map,
)

update_while_output_stopgradient(
grad_op, while_grad_block.ops[-1]
)
# update input_grad map
update_input_grad_map(op, input_grads, origin_inputs)
elif op.name() == "pd_op.pylayer":