Skip to content

Commit

Permalink
Merge pull request #2404 from anutosh491/Fixing_assert
Browse files Browse the repository at this point in the history
Added support for `visit_Assert` through `basic_eq`
  • Loading branch information
certik authored Nov 2, 2023
2 parents 35a480b + 06e0a80 commit dcaf253
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
6 changes: 6 additions & 0 deletions integration_tests/symbolics_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,10 @@ def main0():
assert(z == pi + y)
assert(z != S(2)*pi + y)

# testing PR 2404
p: S = Symbol('pi')
print(p)
print(p != pi)
assert(p != pi)

main0()
34 changes: 25 additions & 9 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1706,16 +1706,32 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
pass_result.push_back(al, assert_stmt);
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)) {
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test);

ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
left_tmp = process_with_basic_str(al, x.base.base.loc, s->m_left, basic_str_sym);
right_tmp = process_with_basic_str(al, x.base.base.loc, s->m_right, basic_str_sym);
ASR::expr_t* test = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp,
s->m_op, right_tmp, s->m_type, s->m_value));
ASR::SymbolicCompare_t* s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test);
if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {
ASR::symbol_t* sym = nullptr;
if (s->m_op == ASR::cmpopType::Eq) {
sym = declare_basic_eq_function(al, x.base.base.loc, module_scope);
} else {
sym = declare_basic_neq_function(al, x.base.base.loc, module_scope);
}
ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left);
ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = value1;
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = value2;
call_args.push_back(al, call_arg1);
call_args.push_back(al, call_arg2);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
sym, sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr));

ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
pass_result.push_back(al, assert_stmt);
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, function_call, x.m_msg));
pass_result.push_back(al, assert_stmt);
}
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*x.m_test)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(x.m_test);
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
Expand Down

0 comments on commit dcaf253

Please sign in to comment.