Skip to content

Commit

Permalink
Implementing get argument through design 1
Browse files Browse the repository at this point in the history
  • Loading branch information
anutosh491 committed Oct 28, 2023
1 parent 3ed3098 commit fb809f3
Show file tree
Hide file tree
Showing 4 changed files with 327 additions and 0 deletions.
55 changes: 55 additions & 0 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ enum class IntrinsicScalarFunctions : int64_t {
SymbolicPowQ,
SymbolicLogQ,
SymbolicSinQ,
SymbolicGetArgument,
// ...
};

Expand Down Expand Up @@ -152,6 +153,7 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(SymbolicPowQ)
INTRINSIC_NAME_CASE(SymbolicLogQ)
INTRINSIC_NAME_CASE(SymbolicSinQ)
INTRINSIC_NAME_CASE(SymbolicGetArgument)
default : {
throw LCompilersException("pickle: intrinsic_id not implemented");
}
Expand Down Expand Up @@ -3116,6 +3118,54 @@ namespace SymbolicHasSymbolQ {
}
} // namespace SymbolicHasSymbolQ

namespace SymbolicGetArgument {
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x,
diag::Diagnostics& diagnostics) {
ASRUtils::require_impl(x.n_args == 2, "Intrinsic function SymbolicGetArgument"
"accepts exactly 2 argument", x.base.base.loc, diagnostics);

ASR::ttype_t* arg1_type = ASRUtils::expr_type(x.m_args[0]);
ASR::ttype_t* arg2_type = ASRUtils::expr_type(x.m_args[1]);
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*arg1_type),
"SymbolicGetArgument expects the first argument to be of type SymbolicExpression",
x.base.base.loc, diagnostics);
ASRUtils::require_impl(ASR::is_a<ASR::Integer_t>(*arg2_type),
"SymbolicGetArgument expects the second argument to be of type Integer",
x.base.base.loc, diagnostics);
}

static inline ASR::expr_t* eval_SymbolicGetArgument(Allocator &/*al*/,
const Location &/*loc*/, ASR::ttype_t *, Vec<ASR::expr_t*> &/*args*/) {
/*TODO*/
return nullptr;
}

static inline ASR::asr_t* create_SymbolicGetArgument(Allocator& al,
const Location& loc, Vec<ASR::expr_t*>& args,
const std::function<void (const std::string &, const Location &)> err) {

if (args.size() != 2) {
err("Intrinsic function SymbolicGetArguments accepts exactly 2 argument", loc);
}

ASR::ttype_t* arg1_type = ASRUtils::expr_type(args[0]);
ASR::ttype_t* arg2_type = ASRUtils::expr_type(args[1]);
if (!ASR::is_a<ASR::SymbolicExpression_t>(*arg1_type)) {
err("The first argument of SymbolicGetArgument function must be of type SymbolicExpression",
args[0]->base.loc);
}
if (!ASR::is_a<ASR::Integer_t>(*arg2_type)) {
err("The second argument of SymbolicGetArgument function must be of type Integer",
args[1]->base.loc);
}

ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc));
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_SymbolicGetArgument,
static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicGetArgument),
0, to_type);
}
} // namespace SymbolicGetArgument

#define create_symbolic_query_macro(X) \
namespace X { \
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \
Expand Down Expand Up @@ -3325,6 +3375,8 @@ namespace IntrinsicScalarFunctionRegistry {
{nullptr, &SymbolicLogQ::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicSinQ),
{nullptr, &SymbolicSinQ::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicGetArgument),
{nullptr, &SymbolicGetArgument::verify_args}},
};

static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
Expand Down Expand Up @@ -3441,6 +3493,8 @@ namespace IntrinsicScalarFunctionRegistry {
"SymbolicLogQ"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicSinQ),
"SymbolicSinQ"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicGetArgument),
"SymbolicGetArgument"},
};


Expand Down Expand Up @@ -3502,6 +3556,7 @@ namespace IntrinsicScalarFunctionRegistry {
{"PowQ", {&SymbolicPowQ::create_SymbolicPowQ, &SymbolicPowQ::eval_SymbolicPowQ}},
{"LogQ", {&SymbolicLogQ::create_SymbolicLogQ, &SymbolicLogQ::eval_SymbolicLogQ}},
{"SinQ", {&SymbolicSinQ::create_SymbolicSinQ, &SymbolicSinQ::eval_SymbolicSinQ}},
{"GetArgument", {&SymbolicGetArgument::create_SymbolicGetArgument, &SymbolicGetArgument::eval_SymbolicGetArgument}},
};

static inline bool is_intrinsic_function(const std::string& name) {
Expand Down
234 changes: 234 additions & 0 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,84 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
process_unary_operator(al, loc, x, module_scope, "basic_expand", target);
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicGetArgument: {
// Define necessary function symbols
ASR::expr_t* value1 = handle_argument(al, loc, x->m_args[0]);
ASR::symbol_t* basic_get_args_sym = declare_basic_get_args_function(al, loc, module_scope);
ASR::symbol_t* vecbasic_new_sym = declare_vecbasic_new_function(al, loc, module_scope);
ASR::symbol_t* vecbasic_get_sym = declare_vecbasic_get_function(al, loc, module_scope);
ASR::symbol_t* vecbasic_size_sym = declare_vecbasic_size_function(al, loc, module_scope);

// Define necessary variables
ASR::ttype_t* CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, loc));
ASR::symbol_t* args_sym = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, current_scope, s2c(al, "args"), nullptr, 0, ASR::intentType::Local,
nullptr, nullptr, ASR::storage_typeType::Default, CPtr_type, nullptr,
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
//if (!current_scope->get_symbol("args")) {
current_scope->add_symbol(s2c(al, "args"), args_sym);
//}

// Statement 1
ASR::expr_t* args = ASRUtils::EXPR(ASR::make_Var_t(al, loc, args_sym));
Vec<ASR::call_arg_t> call_args1;
call_args1.reserve(al, 1);
ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
vecbasic_new_sym, vecbasic_new_sym, call_args1.p, call_args1.n,
ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), nullptr, nullptr));
ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, args, function_call1, nullptr));
pass_result.push_back(al, stmt1);

// Statement 2
Vec<ASR::call_arg_t> call_args2;
call_args2.reserve(al, 2);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = loc;
call_arg1.m_value = value1;
call_arg2.loc = loc;
call_arg2.m_value = args;
call_args2.push_back(al, call_arg1);
call_args2.push_back(al, call_arg2);
ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_get_args_sym,
basic_get_args_sym, call_args2.p, call_args2.n, nullptr));
pass_result.push_back(al, stmt2);

// Statement 3
Vec<ASR::stmt_t *> if_body; if_body.reserve(al, 1);
Vec<ASR::stmt_t *> else_body; else_body.reserve(al, 1);
Vec<ASR::call_arg_t> call_args3;
call_args3.reserve(al, 1);
ASR::call_arg_t call_arg3;
call_arg3.loc = loc;
call_arg3.m_value = args;
call_args3.push_back(al, call_arg3);
ASR::expr_t* function_call2 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
vecbasic_size_sym, vecbasic_size_sym, call_args3.p, call_args3.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
ASR::expr_t* int_compare = ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call2, ASR::cmpopType::LtE,
x->m_args[1], ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
ASR::stmt_t* stmt3 = ASRUtils::STMT(make_If_t(al, loc, int_compare,
if_body.p, if_body.n, else_body.p, else_body.n));
pass_result.push_back(al, stmt3);

// Statement 4
Vec<ASR::call_arg_t> call_args4;
call_args4.reserve(al, 3);
ASR::call_arg_t call_arg4, call_arg5, call_arg6;
call_arg4.loc = loc;
call_arg4.m_value = args;
call_arg5.loc = loc;
call_arg5.m_value = x->m_args[1];
call_arg6.loc = loc;
call_arg6.m_value = target;
call_args4.push_back(al, call_arg4);
call_args4.push_back(al, call_arg5);
call_args4.push_back(al, call_arg6);
ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, vecbasic_get_sym,
vecbasic_get_sym, call_args4.p, call_args4.n, nullptr));
pass_result.push_back(al, stmt4);
break;
}
default: {
throw LCompilersException("IntrinsicFunction: `"
+ ASRUtils::get_intrinsic_name(intrinsic_id)
Expand Down Expand Up @@ -723,6 +801,162 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_basic_get_args_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_get_args";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 2);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg1);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)));
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "y"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_vecbasic_new_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "vecbasic_new";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 1);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE((ASR::make_CPtr_t(al, loc))),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable")));
ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
return_var, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_vecbasic_get_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "vecbasic_get";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 3);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg1);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)));
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE((ASR::make_Integer_t(al, loc, 4))),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "y"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));
ASR::symbol_t* arg3 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "z"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "z"), arg3);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_vecbasic_size_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "vecbasic_size";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 1);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable")));
ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
return_var, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_eq";
symbolic_dependencies.push_back(name);
Expand Down
Loading

0 comments on commit fb809f3

Please sign in to comment.