From fb809f3ae40bc8ca0f3d1c692de75af72ed315c2 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sat, 28 Oct 2023 21:04:02 +0530 Subject: [PATCH] Implementing get argument through design 1 --- src/libasr/pass/intrinsic_function_registry.h | 55 ++++ src/libasr/pass/replace_symbolic.cpp | 234 ++++++++++++++++++ src/lpython/semantics/python_ast_to_asr.cpp | 24 ++ src/lpython/semantics/python_attribute_eval.h | 14 ++ 4 files changed, 327 insertions(+) diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index ff4c92548e..9810776d8b 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -85,6 +85,7 @@ enum class IntrinsicScalarFunctions : int64_t { SymbolicPowQ, SymbolicLogQ, SymbolicSinQ, + SymbolicGetArgument, // ... }; @@ -152,6 +153,7 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(SymbolicPowQ) INTRINSIC_NAME_CASE(SymbolicLogQ) INTRINSIC_NAME_CASE(SymbolicSinQ) + INTRINSIC_NAME_CASE(SymbolicGetArgument) default : { throw LCompilersException("pickle: intrinsic_id not implemented"); } @@ -3116,6 +3118,54 @@ namespace SymbolicHasSymbolQ { } } // namespace SymbolicHasSymbolQ +namespace SymbolicGetArgument { + static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, + diag::Diagnostics& diagnostics) { + ASRUtils::require_impl(x.n_args == 2, "Intrinsic function SymbolicGetArgument" + "accepts exactly 2 argument", x.base.base.loc, diagnostics); + + ASR::ttype_t* arg1_type = ASRUtils::expr_type(x.m_args[0]); + ASR::ttype_t* arg2_type = ASRUtils::expr_type(x.m_args[1]); + ASRUtils::require_impl(ASR::is_a(*arg1_type), + "SymbolicGetArgument expects the first argument to be of type SymbolicExpression", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(ASR::is_a(*arg2_type), + "SymbolicGetArgument expects the second argument to be of type Integer", + x.base.base.loc, diagnostics); + } + + static inline ASR::expr_t* eval_SymbolicGetArgument(Allocator &/*al*/, + const Location &/*loc*/, ASR::ttype_t *, Vec &/*args*/) { + /*TODO*/ + return nullptr; + } + + static inline ASR::asr_t* create_SymbolicGetArgument(Allocator& al, + const Location& loc, Vec& args, + const std::function err) { + + if (args.size() != 2) { + err("Intrinsic function SymbolicGetArguments accepts exactly 2 argument", loc); + } + + ASR::ttype_t* arg1_type = ASRUtils::expr_type(args[0]); + ASR::ttype_t* arg2_type = ASRUtils::expr_type(args[1]); + if (!ASR::is_a(*arg1_type)) { + err("The first argument of SymbolicGetArgument function must be of type SymbolicExpression", + args[0]->base.loc); + } + if (!ASR::is_a(*arg2_type)) { + err("The second argument of SymbolicGetArgument function must be of type Integer", + args[1]->base.loc); + } + + ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc)); + return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_SymbolicGetArgument, + static_cast(IntrinsicScalarFunctions::SymbolicGetArgument), + 0, to_type); + } +} // namespace SymbolicGetArgument + #define create_symbolic_query_macro(X) \ namespace X { \ static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \ @@ -3325,6 +3375,8 @@ namespace IntrinsicScalarFunctionRegistry { {nullptr, &SymbolicLogQ::verify_args}}, {static_cast(IntrinsicScalarFunctions::SymbolicSinQ), {nullptr, &SymbolicSinQ::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicGetArgument), + {nullptr, &SymbolicGetArgument::verify_args}}, }; static const std::map& intrinsic_function_id_to_name = { @@ -3441,6 +3493,8 @@ namespace IntrinsicScalarFunctionRegistry { "SymbolicLogQ"}, {static_cast(IntrinsicScalarFunctions::SymbolicSinQ), "SymbolicSinQ"}, + {static_cast(IntrinsicScalarFunctions::SymbolicGetArgument), + "SymbolicGetArgument"}, }; @@ -3502,6 +3556,7 @@ namespace IntrinsicScalarFunctionRegistry { {"PowQ", {&SymbolicPowQ::create_SymbolicPowQ, &SymbolicPowQ::eval_SymbolicPowQ}}, {"LogQ", {&SymbolicLogQ::create_SymbolicLogQ, &SymbolicLogQ::eval_SymbolicLogQ}}, {"SinQ", {&SymbolicSinQ::create_SymbolicSinQ, &SymbolicSinQ::eval_SymbolicSinQ}}, + {"GetArgument", {&SymbolicGetArgument::create_SymbolicGetArgument, &SymbolicGetArgument::eval_SymbolicGetArgument}}, }; static inline bool is_intrinsic_function(const std::string& name) { diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 3bae7436f8..549f16f70c 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -559,6 +559,84 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]); + ASR::symbol_t* basic_get_args_sym = declare_basic_get_args_function(al, loc, module_scope); + ASR::symbol_t* vecbasic_new_sym = declare_vecbasic_new_function(al, loc, module_scope); + ASR::symbol_t* vecbasic_get_sym = declare_vecbasic_get_function(al, loc, module_scope); + ASR::symbol_t* vecbasic_size_sym = declare_vecbasic_size_function(al, loc, module_scope); + + // Define necessary variables + ASR::ttype_t* CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)); + ASR::symbol_t* args_sym = ASR::down_cast(ASR::make_Variable_t( + al, loc, current_scope, s2c(al, "args"), nullptr, 0, ASR::intentType::Local, + nullptr, nullptr, ASR::storage_typeType::Default, CPtr_type, nullptr, + ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + //if (!current_scope->get_symbol("args")) { + current_scope->add_symbol(s2c(al, "args"), args_sym); + //} + + // Statement 1 + ASR::expr_t* args = ASRUtils::EXPR(ASR::make_Var_t(al, loc, args_sym)); + Vec call_args1; + call_args1.reserve(al, 1); + ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + vecbasic_new_sym, vecbasic_new_sym, call_args1.p, call_args1.n, + ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), nullptr, nullptr)); + ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, args, function_call1, nullptr)); + pass_result.push_back(al, stmt1); + + // Statement 2 + Vec call_args2; + call_args2.reserve(al, 2); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = loc; + call_arg1.m_value = value1; + call_arg2.loc = loc; + call_arg2.m_value = args; + call_args2.push_back(al, call_arg1); + call_args2.push_back(al, call_arg2); + ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_get_args_sym, + basic_get_args_sym, call_args2.p, call_args2.n, nullptr)); + pass_result.push_back(al, stmt2); + + // Statement 3 + Vec if_body; if_body.reserve(al, 1); + Vec else_body; else_body.reserve(al, 1); + Vec call_args3; + call_args3.reserve(al, 1); + ASR::call_arg_t call_arg3; + call_arg3.loc = loc; + call_arg3.m_value = args; + call_args3.push_back(al, call_arg3); + ASR::expr_t* function_call2 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + vecbasic_size_sym, vecbasic_size_sym, call_args3.p, call_args3.n, + ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr)); + ASR::expr_t* int_compare = ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call2, ASR::cmpopType::LtE, + x->m_args[1], ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr)); + ASR::stmt_t* stmt3 = ASRUtils::STMT(make_If_t(al, loc, int_compare, + if_body.p, if_body.n, else_body.p, else_body.n)); + pass_result.push_back(al, stmt3); + + // Statement 4 + Vec call_args4; + call_args4.reserve(al, 3); + ASR::call_arg_t call_arg4, call_arg5, call_arg6; + call_arg4.loc = loc; + call_arg4.m_value = args; + call_arg5.loc = loc; + call_arg5.m_value = x->m_args[1]; + call_arg6.loc = loc; + call_arg6.m_value = target; + call_args4.push_back(al, call_arg4); + call_args4.push_back(al, call_arg5); + call_args4.push_back(al, call_arg6); + ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, vecbasic_get_sym, + vecbasic_get_sym, call_args4.p, call_args4.n, nullptr)); + pass_result.push_back(al, stmt4); + break; + } default: { throw LCompilersException("IntrinsicFunction: `" + ASRUtils::get_intrinsic_name(intrinsic_id) @@ -723,6 +801,162 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } + ASR::symbol_t* declare_basic_get_args_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { + std::string name = "basic_get_args"; + symbolic_dependencies.push_back(name); + if (!module_scope->get_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 2); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + return module_scope->get_symbol(name); + } + + ASR::symbol_t* declare_vecbasic_new_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { + std::string name = "vecbasic_new"; + symbolic_dependencies.push_back(name); + if (!module_scope->get_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE((ASR::make_CPtr_t(al, loc))), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + return module_scope->get_symbol(name); + } + + ASR::symbol_t* declare_vecbasic_get_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { + std::string name = "vecbasic_get"; + symbolic_dependencies.push_back(name); + if (!module_scope->get_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 3); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE((ASR::make_Integer_t(al, loc, 4))), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "z"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "z"), arg3); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + return module_scope->get_symbol(name); + } + + ASR::symbol_t* declare_vecbasic_size_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { + std::string name = "vecbasic_size"; + symbolic_dependencies.push_back(name); + if (!module_scope->get_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + return module_scope->get_symbol(name); + } + ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { std::string name = "basic_eq"; symbolic_dependencies.push_back(name); diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index d1096a83ea..cc0a43fdd4 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -369,6 +369,7 @@ class CommonVisitor : public AST::BaseVisitor { bool allow_implicit_casting; // Stores the name of imported functions and the modules they are imported from std::map imported_functions; + bool using_args_attr = false; std::map numpy2lpythontypes = { {"bool", "bool"}, @@ -3802,6 +3803,25 @@ class CommonVisitor : public AST::BaseVisitor { void visit_Subscript(const AST::Subscript_t &x) { this->visit_expr(*x.m_value); + if (using_args_attr) { + if (AST::is_a(*x.m_value)){ + AST::Attribute_t *attr = AST::down_cast(x.m_value); + if (AST::is_a(*attr->m_value)) { + AST::Name_t *var_name = AST::down_cast(attr->m_value); + std::string var = var_name->m_id; + ASR::symbol_t *st = current_scope->resolve_symbol(var); + ASR::expr_t *se = ASR::down_cast( + ASR::make_Var_t(al, x.base.base.loc, st)); + Vec args; + args.reserve(al, 0); + this->visit_expr(*x.m_slice); + ASR::expr_t *index = ASRUtils::EXPR(tmp); + args.push_back(al, index); + tmp = attr_handler.eval_symbolic_get_argument(se, al, x.base.base.loc, args, diag); + return; + } + } + } ASR::expr_t *value = ASRUtils::EXPR(tmp); ASR::ttype_t *type = ASRUtils::expr_type(value); Vec args; @@ -5822,6 +5842,10 @@ class BodyVisitor : public CommonVisitor { using_func_attr = true; return; } + if (attr == "args") { + using_args_attr = true; + return; + } ASR::expr_t *se = ASR::down_cast(ASR::make_Var_t(al, loc, t)); Vec args; args.reserve(al, 0); diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index 580434bc4f..1aab45b379 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -527,6 +527,20 @@ struct AttributeHandler { { throw SemanticError(msg, loc); }); } + static ASR::asr_t* eval_symbolic_get_argument(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("GetArgument"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + }; // AttributeHandler } // namespace LCompilers::LPython