
[Dlight] Add fallback for low batch gemv with outer reduction #16701

Merged: 1 commit into apache:main on Mar 12, 2024

Conversation

jinhongyii (Contributor)

Add fallback for low batch gemv with outer reduction

kmn1024 commented Mar 12, 2024

Thanks for the fix! I tried it and got past the previous error, but now the same compile command gives a new error:

[2024-03-12 09:57:31] INFO pipeline.py:43: Running TVM Dlight low-level optimizations
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/__main__.py", line 47, in <module>
    main()
  File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/__main__.py", line 24, in main
    cli.main(sys.argv[2:])
  File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/cli/compile.py", line 131, in main
    compile(
  File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/interface/compile.py", line 229, in compile
    _compile(args, model_config)
  File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/interface/compile.py", line 176, in _compile
    args.build_func(
  File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/support/auto_target.py", line 235, in build
    relax.build(
  File "/home/ubuntu/new-mlc/relax/python/tvm/relax/vm_build.py", line 335, in build
    mod = pipeline(mod)
  File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/ubuntu/new-mlc/relax/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/ubuntu/new-mlc/relax/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/compiler_pass/pipeline.py", line 151, in _pipeline
    mod = seq(mod)
  File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 307, in _pass_func
    return inst.transform_module(mod, ctx)
  File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/compiler_pass/low_batch_specialization.py", line 28, in transform_module
    low_batch_mod = dl.ApplyDefaultSchedule(
  File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 307, in _pass_func
    return inst.transform_module(mod, ctx)
  File "/home/ubuntu/new-mlc/relax/python/tvm/dlight/base/transform.py", line 64, in transform_module
    sch = _apply_rules(func, target, self.rules, tunable=False)
  File "/home/ubuntu/new-mlc/relax/python/tvm/dlight/base/transform.py", line 80, in _apply_rules
    space = rule.apply(func, target, tunable)
  File "/home/ubuntu/new-mlc/relax/python/tvm/dlight/gpu/low_batch_gemv.py", line 273, in apply
    is_inner_reduction = normalize(sch, block_info)
  File "/home/ubuntu/new-mlc/relax/python/tvm/dlight/gpu/low_batch_gemv.py", line 200, in normalize
    sch.reorder(*dynamic_loops, *batch_loops, *s_loops, *r_loops, *c_loops)
  File "/home/ubuntu/new-mlc/relax/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
    return func(*args, **kwargs)
  File "/home/ubuntu/new-mlc/relax/python/tvm/tir/schedule/schedule.py", line 982, in reorder
    _ffi_api.ScheduleReorder(self, ordered_loops)  # type: ignore # pylint: disable=no-member
  File "/home/ubuntu/new-mlc/relax/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/ubuntu/new-mlc/relax/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
  1: tvm::tir::TracedScheduleNode::Reorder(tvm::runtime::Array<tvm::tir::LoopRV, void> const&)
        at /home/ubuntu/new-mlc/relax/src/tir/schedule/traced_schedule.cc:269
  0: tvm::tir::ConcreteScheduleNode::Reorder(tvm::runtime::Array<tvm::tir::LoopRV, void> const&)
        at /home/ubuntu/new-mlc/relax/src/tir/schedule/concrete_schedule.cc:589
ScheduleError: An error occurred in the schedule primitive 'reorder'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def main(var_A: T.handle, var_B: T.handle, var_matmul: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        total_seq_len = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), total_seq_len), "float16")
        B = T.match_buffer(var_B, (T.int64(1), T.int64(32), total_seq_len, T.int64(64)), "float16")
        matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), T.int64(1), T.int64(64)), "float16")
        with T.block("root"):
            T.reads()
            T.writes()
            A_pad = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), (total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2)), "float16")
            B_pad = T.alloc_buffer((T.int64(1), T.int64(32), (total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(64)), "float16")
            for ax0 in range(T.int64(32)):
                for ax1 in range((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2)):
                    with T.block("A_pad"):
                        v0 = T.axis.spatial(T.int64(32), ax0)
                        v1 = T.axis.spatial((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax1)
                        T.reads(A[T.int64(0), v0, T.int64(0), v1])
                        T.writes(A_pad[T.int64(0), v0, T.int64(0), v1])
                        A_pad[T.int64(0), v0, T.int64(0), v1] = T.if_then_else(v1 < total_seq_len, A[T.int64(0), v0, T.int64(0), v1], T.float16(0))
            for ax0 in range(T.int64(32)):
                for ax1 in range((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2)):
                    for ax2 in range(T.int64(64)):
                        with T.block("B_pad"):
                            v0 = T.axis.spatial(T.int64(32), ax0)
                            v1 = T.axis.spatial((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax1)
                            v2 = T.axis.spatial(T.int64(64), ax2)
                            T.reads(B[T.int64(0), v0, v1, v2])
                            T.writes(B_pad[T.int64(0), v0, v1, v2])
                            B_pad[T.int64(0), v0, v1, v2] = T.if_then_else(v1 < total_seq_len, B[T.int64(0), v0, v1, v2], T.float16(0))
            for ax0 in range(T.int64(32)):
                for ax1 in range(T.int64(64)):
                    # tir.For#0
                    for ax2 in range((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2)):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        for u in range(1):
                        ^^^^^^^^^^^^^^^^^^
                            with T.block("matmul"):
                            ^^^^^^^^^^^^^^^^^^^^^^^
                                v0 = T.axis.spatial(T.int64(32), ax0)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v1 = T.axis.spatial(T.int64(64), ax1)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v2 = T.axis.reduce((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax2)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                T.reads(A_pad[T.int64(0), v0, T.int64(0), v2], B_pad[T.int64(0), v0, v2, v1])
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                T.writes(matmul[T.int64(0), v0, T.int64(0), v1])
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                with T.init():
                                ^^^^^^^^^^^^^^
                                    matmul[T.int64(0), v0, T.int64(0), v1] = T.float16(0)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                matmul[T.int64(0), v0, T.int64(0), v1] = matmul[T.int64(0), v0, T.int64(0), v1] + A_pad[T.int64(0), v0, T.int64(0), v2] * B_pad[T.int64(0), v0, v2, v1]
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error message: Loop tir.For#0 appears in the input array for multiple times.
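
For context: tir.Schedule.reorder rejects an argument list in which the same loop handle appears more than once, and the traceback shows normalize in low_batch_gemv.py calling sch.reorder(*dynamic_loops, *batch_loops, *s_loops, *r_loops, *c_loops), so the same tir.For was evidently collected into more than one of those lists. A minimal sketch on a hypothetical toy kernel (not the code touched by this PR) that triggers the same class of error:

import tvm
from tvm.script import tir as T

@T.prim_func
def toy(A: T.Buffer((8, 8), "float16"), B: T.Buffer((8, 8), "float16")):
    # Trivial copy kernel; only its two loops matter for the example.
    for i, j in T.grid(8, 8):
        with T.block("copy"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

sch = tvm.tir.Schedule(toy)
i, j = sch.get_loops(sch.get_block("copy"))
sch.reorder(j, i)  # valid: each loop listed exactly once
try:
    sch.reorder(i, i)  # the same loop passed twice, as in the failure above
except tvm.tir.schedule.ScheduleError as err:
    print(err)  # "... appears in the input array for multiple times"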

jinhongyii (Contributor, Author)

I can't reproduce your error on my branch. I think the error you are seeing was already fixed in 5bbe1ab.

Please check whether you have that commit in your local branch.
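
(For reference, assuming a git checkout of the relevant branch: a commit's presence can be verified with "git branch --contains 5bbe1ab", which lists the local branches that include it, or with "git merge-base --is-ancestor 5bbe1ab HEAD", which exits 0 if it is an ancestor of the current checkout.)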

tqchen merged commit fe340c9 into apache:main on Mar 12, 2024
19 checks passed
kmn1024 commented Mar 13, 2024

Yes, you are right. Thanks so much!

thaisacs pushed a commit to thaisacs/tvm referencing this pull request on Apr 3, 2024.