Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Sep 26, 2021
1 parent 1f38801 commit 4e7784f
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class LinkedParam : public ObjectRef {
*
* \code
* @T.prim_func
* def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: ty.int32) -> None:
* def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
* A = T.match_buffer(a, (m, n), "float32")
* B = T.match_buffer(b, (m, n), "float32")
*
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]):
.. code-block:: python
@T.prim_func
def mem_copy(a: T.handle, b: T.handle, m: ty.int32, n: ty.int32) -> None:
def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
A = T.match_buffer(a, (m, n), "float32")
B = T.match_buffer(b, (m, n), "float32")
Expand Down
2 changes: 1 addition & 1 deletion src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ Doc TVMScriptPrinter::VisitType_(const TupleTypeNode* node) {
for (Type field : node->fields) {
fields.push_back(Print(field));
}
return Doc::Text("ty.Tuple[") << Doc::Concat(fields) << "]";
return Doc::Text("T.Tuple[") << Doc::Concat(fields) << "]";
}
}

Expand Down
11 changes: 7 additions & 4 deletions tests/python/unittest/test_meta_schedule_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,16 @@
PyBuilder,
)
from tvm.runtime import Module
from tvm.script import ty
from tvm.script import tir as T
from tvm.target import Target


# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring


@script.tir
@script.ir_module
class MatmulModule:
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument
tir.func_attr({"global_symbol": "matmul", "tir.noalias": True})
A = tir.match_buffer(a, (1024, 1024), "float32")
Expand All @@ -52,8 +53,9 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@script.tir
@script.ir_module
class MatmulReluModule:
@T.prim_func
def matmul_relu( # pylint: disable=no-self-argument
a: T.handle, b: T.handle, d: T.handle
) -> None:
Expand All @@ -70,8 +72,9 @@ def matmul_relu( # pylint: disable=no-self-argument
D[vi, vj] = tir.max(C[vi, vj], 0.0)


@script.tir
@script.ir_module
class BatchMatmulModule:
@T.prim_func
def batch_matmul( # pylint: disable=no-self-argument
a: T.handle, b: T.handle, c: T.handle
) -> None:
Expand Down

0 comments on commit 4e7784f

Please sign in to comment.