
[Bug] [Relax] Argument type mismatch: expected R.Tensor, given R.Tuple #17223

Closed
Cookiee235 opened this issue Jul 31, 2024 · 1 comment
Labels: needs-triage, type: bug

@Cookiee235 (Contributor)
The provided Relax IR appears to be valid; however, it unexpectedly crashes when compiled with relax.build().

Actual behavior

Traceback (most recent call last):
  File "test_simp.py", line 26, in <module>
    ex = relax.build(mod, target='llvm')  # crash here!
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/relax/vm_build.py", line 335, in build
    mod = pipeline(mod)
          ^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/software/tvm-lunder/python/tvm/relax/pipeline.py", line 101, in _pipeline
    mod = seq(mod)
          ^^^^^^^^
  File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  33: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  32: tvm::transform::Pass::operator()(tvm::IRModule) const
  31: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  30: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  29: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  28: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  27: _ZN3tvm7runtime13PackedFuncObj
  26: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  25: tvm::relax::CallTIRMutator::Run()
  24: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  23: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  22: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
  21: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
  20: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  19: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  18: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
  17: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
  16: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
  15: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
  14: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
  13: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::ConstantNode const*)
  12: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  11: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  10: tvm::relax::CallTIRMutator::VisitExpr_(tvm::relax::CallNode const*)
  9: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, tvm::runtime::String)
  8: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, bool, tvm::runtime::String)
  7: tvm::relax::Normalizer::Normalize(tvm::RelayExpr const&)
  6: tvm::relax::Normalizer::VisitExpr(tvm::RelayExpr const&)
  5: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  4: tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
  3: tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
  2: tvm::relax::DeriveCallRetStructInfo(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::arith::Analyzer*)
  1: tvm::relax::CallRetStructInfoDeriver::Derive(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
  0: tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
  File "/software/tvm-lunder/src/relax/ir/block_builder.cc", line 158
TVMError: Argument 0 type mismatch: expected R.Tensor((16,), dtype="float32"), given R.Tuple(R.Tensor((16,), dtype="float32"))

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module(check_well_formed=True)
class Module:
    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
        for i in range(16):
            B[i] = A[i] * T.float32(2)

    @R.function
    def main(A: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
        cls = Module
        args: R.Tuple(R.Tensor((16,), dtype="float32")) = (A,)
        gv1 = R.call_tir(cls.multiply_by_two, (args,), out_sinfo=R.Tensor((16,), dtype="float32"))
        return gv1

mod = Module
mod.show()

mod = relax.transform.FuseTIR()(mod)
mod = relax.transform.LambdaLift()(mod)
ex = relax.build(mod, target='llvm')  # crash here!

cc @Lunderberg @junrushao

@Cookiee235 added labels needs-triage, type: bug on Jul 31, 2024
@Lunderberg (Contributor)

I can run the test case and reproduce the error, but the error message looks correct for this test case. The first argument to Module.multiply_by_two is a tensor, while the first item of R.call_tir's argument tuple is itself a tuple. This could be caught earlier by the well-formed checker, once it is updated to validate R.call_tir arguments.

(As a side-note, replacing (args,) with args would have the correct struct info, but wouldn't be an in-line relax Tuple as required by R.call_tir. See the discussion in #15916 for more detail on the requirement for an in-line tuple.)

@tqchen tqchen closed this as completed Feb 8, 2025