Skip to content

Commit

Permalink
Merge pull request #2399 from anutosh491/Second_method_for_args
Browse files Browse the repository at this point in the history
Implementing get argument through design 1
  • Loading branch information
certik authored Oct 31, 2023
2 parents 5072c12 + 5455dd0 commit da8d25b
Show file tree
Hide file tree
Showing 6 changed files with 360 additions and 3 deletions.
4 changes: 4 additions & 0 deletions integration_tests/symbolics_02.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def test_symbolic_operations():
else:
assert False
assert(z.func == Add)
assert(z.args[0] == x or z.args[0] == y)
assert(z.args[1] == y or z.args[1] == x)
print(z)

# Subtraction
Expand All @@ -43,6 +45,8 @@ def test_symbolic_operations():
else:
assert False
assert(u.func == Mul)
assert(u.args[0] == x)
assert(u.args[1] == y)
print(u)

# Division
Expand Down
7 changes: 7 additions & 0 deletions integration_tests/symbolics_05.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,12 @@ def test_operations():
assert((sin(x) + cos(x)).diff(x) == S(-1)*c + d)
assert((sin(x) + cos(x) + exp(x) + pi).diff(x).expand().diff(x) == exp(x) + S(-1)*c + S(-1)*d)

# test args
assert(a.args[0] == x + y)
assert(a.args[1] == S(2))
assert(b.args[0] == x + y + z)
assert(b.args[1] == S(3))
assert(c.args[0] == x)
assert(d.args[0] == x)

test_operations()
55 changes: 55 additions & 0 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ enum class IntrinsicScalarFunctions : int64_t {
SymbolicPowQ,
SymbolicLogQ,
SymbolicSinQ,
SymbolicGetArgument,
// ...
};

Expand Down Expand Up @@ -152,6 +153,7 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(SymbolicPowQ)
INTRINSIC_NAME_CASE(SymbolicLogQ)
INTRINSIC_NAME_CASE(SymbolicSinQ)
INTRINSIC_NAME_CASE(SymbolicGetArgument)
default : {
throw LCompilersException("pickle: intrinsic_id not implemented");
}
Expand Down Expand Up @@ -3116,6 +3118,54 @@ namespace SymbolicHasSymbolQ {
}
} // namespace SymbolicHasSymbolQ

namespace SymbolicGetArgument {
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x,
diag::Diagnostics& diagnostics) {
ASRUtils::require_impl(x.n_args == 2, "Intrinsic function SymbolicGetArgument"
"accepts exactly 2 argument", x.base.base.loc, diagnostics);

ASR::ttype_t* arg1_type = ASRUtils::expr_type(x.m_args[0]);
ASR::ttype_t* arg2_type = ASRUtils::expr_type(x.m_args[1]);
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*arg1_type),
"SymbolicGetArgument expects the first argument to be of type SymbolicExpression",
x.base.base.loc, diagnostics);
ASRUtils::require_impl(ASR::is_a<ASR::Integer_t>(*arg2_type),
"SymbolicGetArgument expects the second argument to be of type Integer",
x.base.base.loc, diagnostics);
}

static inline ASR::expr_t* eval_SymbolicGetArgument(Allocator &/*al*/,
const Location &/*loc*/, ASR::ttype_t *, Vec<ASR::expr_t*> &/*args*/) {
/*TODO*/
return nullptr;
}

static inline ASR::asr_t* create_SymbolicGetArgument(Allocator& al,
const Location& loc, Vec<ASR::expr_t*>& args,
const std::function<void (const std::string &, const Location &)> err) {

if (args.size() != 2) {
err("Intrinsic function SymbolicGetArguments accepts exactly 2 argument", loc);
}

ASR::ttype_t* arg1_type = ASRUtils::expr_type(args[0]);
ASR::ttype_t* arg2_type = ASRUtils::expr_type(args[1]);
if (!ASR::is_a<ASR::SymbolicExpression_t>(*arg1_type)) {
err("The first argument of SymbolicGetArgument function must be of type SymbolicExpression",
args[0]->base.loc);
}
if (!ASR::is_a<ASR::Integer_t>(*arg2_type)) {
err("The second argument of SymbolicGetArgument function must be of type Integer",
args[1]->base.loc);
}

ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc));
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_SymbolicGetArgument,
static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicGetArgument),
0, to_type);
}
} // namespace SymbolicGetArgument

#define create_symbolic_query_macro(X) \
namespace X { \
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \
Expand Down Expand Up @@ -3325,6 +3375,8 @@ namespace IntrinsicScalarFunctionRegistry {
{nullptr, &SymbolicLogQ::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicSinQ),
{nullptr, &SymbolicSinQ::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicGetArgument),
{nullptr, &SymbolicGetArgument::verify_args}},
};

static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
Expand Down Expand Up @@ -3441,6 +3493,8 @@ namespace IntrinsicScalarFunctionRegistry {
"SymbolicLogQ"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicSinQ),
"SymbolicSinQ"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicGetArgument),
"SymbolicGetArgument"},
};


Expand Down Expand Up @@ -3502,6 +3556,7 @@ namespace IntrinsicScalarFunctionRegistry {
{"PowQ", {&SymbolicPowQ::create_SymbolicPowQ, &SymbolicPowQ::eval_SymbolicPowQ}},
{"LogQ", {&SymbolicLogQ::create_SymbolicLogQ, &SymbolicLogQ::eval_SymbolicLogQ}},
{"SinQ", {&SymbolicSinQ::create_SymbolicSinQ, &SymbolicSinQ::eval_SymbolicSinQ}},
{"GetArgument", {&SymbolicGetArgument::create_SymbolicGetArgument, &SymbolicGetArgument::eval_SymbolicGetArgument}},
};

static inline bool is_intrinsic_function(const std::string& name) {
Expand Down
Loading

0 comments on commit da8d25b

Please sign in to comment.