Skip to content

Commit

Permalink
Support rich type annotation (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored Jun 25, 2022
1 parent 92c56d6 commit 342c85a
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 137 deletions.
52 changes: 26 additions & 26 deletions python/tvm/script/builder/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,56 +77,56 @@
from . import _ffi_api


def boolean(expr):
return _ffi_api.PrimType("bool", expr)
def int8(expr=None):
return _ffi_api.Int8(expr)


def int8(expr):
return _ffi_api.PrimType("int8", expr)
def int16(expr=None):
return _ffi_api.Int16(expr)


def int16(expr):
return _ffi_api.PrimType("int16", expr)
def int32(expr=None):
return _ffi_api.Int32(expr)


def int32(expr):
return _ffi_api.PrimType("int32", expr)
def int64(expr=None):
return _ffi_api.Int64(expr)


def int64(expr):
return _ffi_api.PrimType("int64", expr)
def uint8(expr=None):
return _ffi_api.UInt8(expr)


def uint8(expr):
return _ffi_api.PrimType("uint8", expr)
def uint16(expr=None):
return _ffi_api.UInt16(expr)


def uint16(expr):
return _ffi_api.PrimType("uint16", expr)
def uint32(expr=None):
return _ffi_api.UInt32(expr)


def uint32(expr):
return _ffi_api.PrimType("uint32", expr)
def uint64(expr=None):
return _ffi_api.UInt64(expr)


def uint64(expr):
return _ffi_api.PrimType("uint64", expr)
def float8(expr=None):
return _ffi_api.Float8(expr)


def float8(expr):
return _ffi_api.PrimType("float8", expr)
def float16(expr=None):
return _ffi_api.Float16(expr)


def float16(expr):
return _ffi_api.PrimType("float16", expr)
def float32(expr=None):
return _ffi_api.Float32(expr)


def float32(expr):
return _ffi_api.PrimType("float32", expr)
def float64(expr=None):
return _ffi_api.Float64(expr)


def float64(expr):
return _ffi_api.PrimType("float64", expr)
def boolean(expr=None):
return _ffi_api.Boolean(expr)


def handle():
Expand Down
26 changes: 17 additions & 9 deletions python/tvm/script/builder/tir/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,20 @@
from . import _ffi_api


def Buffer( # pylint: disable=invalid-name
shape,
dtype="float32",
name="buffer",
storage_scope="",
) -> tir.Buffer:
return _ffi_api.Buffer(
shape, dtype, name, storage_scope
) # pylint: disable=no-member # type: ignore
class BufferProxy:
def __call__(
self,
shape,
dtype="float32",
*,
storage_scope="",
) -> tir.Buffer:
return _ffi_api.Buffer( # pylint: disable=no-member # type: ignore
shape, dtype, "", storage_scope
)

def __getitem__(self, keys) -> tir.Buffer:
return self(*keys) # pylint: disable=no-member # type: ignore


Buffer = BufferProxy()
70 changes: 0 additions & 70 deletions python/tvm/script/parse/doc_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,39 +83,6 @@ def __init__(
self.returns = returns


class AsyncFunctionDef(stmt):
_FIELDS = [
"name",
"args",
"body",
"decorator_list",
"returns",
"lineno",
"col_offset",
"end_lineno",
"end_col_offset",
]

def __init__(
self,
name,
args,
body,
decorator_list,
returns,
lineno,
col_offset,
end_lineno,
end_col_offset,
):
super().__init__(lineno, col_offset, end_lineno, end_col_offset)
self.name = name
self.args = args
self.body = body
self.decorator_list = decorator_list
self.returns = returns


class ClassDef(stmt):
_FIELDS = [
"name",
Expand Down Expand Up @@ -249,26 +216,6 @@ def __init__(self, target, iter, body, orelse, lineno, col_offset, end_lineno, e
self.orelse = orelse


class AsyncFor(stmt):
_FIELDS = [
"target",
"iter",
"body",
"orelse",
"lineno",
"col_offset",
"end_lineno",
"end_col_offset",
]

def __init__(self, target, iter, body, orelse, lineno, col_offset, end_lineno, end_col_offset):
super().__init__(lineno, col_offset, end_lineno, end_col_offset)
self.target = target
self.iter = iter
self.body = body
self.orelse = orelse


class While(stmt):
_FIELDS = [
"test",
Expand Down Expand Up @@ -314,15 +261,6 @@ def __init__(self, items, body, lineno, col_offset, end_lineno, end_col_offset):
self.body = body


class AsyncWith(stmt):
_FIELDS = ["items", "body", "lineno", "col_offset", "end_lineno", "end_col_offset"]

def __init__(self, items, body, lineno, col_offset, end_lineno, end_col_offset):
super().__init__(lineno, col_offset, end_lineno, end_col_offset)
self.items = items
self.body = body


class Raise(stmt):
_FIELDS = ["exc", "cause", "lineno", "col_offset", "end_lineno", "end_col_offset"]

Expand Down Expand Up @@ -595,14 +533,6 @@ def __init__(self, elt, generators, lineno, col_offset, end_lineno, end_col_offs
self.generators = generators


class Await(expr):
_FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"]

def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset):
super().__init__(lineno, col_offset, end_lineno, end_col_offset)
self.value = value


class Yield(expr):
_FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"]

Expand Down
42 changes: 34 additions & 8 deletions python/tvm/script/parse/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@
from .utils import deferred
from .var_table import VarTable

DEFAULT_VISIT = {
"Interactive",
"Module",
"Expression",
"Pass",
}


def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
for token in [self.dispatch_tokens[-1], "default"]:
Expand Down Expand Up @@ -74,27 +81,46 @@ def eval_assign(
def report_error(self, node: doc.AST, msg: str) -> None: # pylint: disable=no-self-use
raise SyntaxError(f"At {node.lineno}:{node.col_offset}: {msg}")

def visit(self, node: doc.AST) -> None:
if isinstance(node, (list, tuple)):
for item in node:
self.visit(item)
return
if not isinstance(node, doc.AST):
return
name = node.__class__.__name__.split(".")[-1]
if name in DEFAULT_VISIT:
func = self.generic_visit
else:
func = getattr(self, "visit_" + name, None)
if func is None:
raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
func(node)

def visit_body(self, node: List[doc.stmt]) -> Any:
for stmt in node:
self.visit(stmt)

def visit_tvm_annotation(self, node: doc.expr) -> Any:
return _dispatch(self, "tvm_annotation")(self, node)

def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name
_handle_function(self, node)

def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name
_handle_class(self, node)

def visit_body(self, node: List[doc.stmt]) -> Any:
for stmt in node:
self.visit(stmt)

def visit_arguments(self, node: doc.arguments) -> Any:
_dispatch(self, "arguments")(self, node)
return _dispatch(self, "arguments")(self, node)

def visit_For(self, node: doc.For) -> Any: # pylint: disable=invalid-name
_dispatch(self, "For")(self, node)
return _dispatch(self, "For")(self, node)

def visit_With(self, node: doc.With) -> Any: # pylint: disable=invalid-name
_dispatch(self, "With")(self, node)
return _dispatch(self, "With")(self, node)

def visit_Assign(self, node: doc.Assign) -> Any: # pylint: disable=invalid-name
_dispatch(self, "Assign")(self, node)
return _dispatch(self, "Assign")(self, node)


def _handle_function(self: Parser, node: doc.FunctionDef) -> None:
Expand Down
11 changes: 10 additions & 1 deletion python/tvm/script/parse/tir/tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,17 @@ def visit_arguments(self: Parser, node: doc.arguments) -> None:
# - kwarg: arg | None
# - defaults: list[expr]
# - posonlyargs: list[arg]
arg: doc.arg
for arg in node.args:
if arg.annotation is None:
self.report_error(arg, "Type annotation is required for function parameters.")
param = T.arg(arg.arg, self.eval_expr(arg.annotation))
param = T.arg(arg.arg, self.visit_tvm_annotation(arg.annotation))
self.var_table.add(arg.arg, param)


@dispatch.register(token="tir", type_name="tvm_annotation")
def visit_tvm_annotation(self: Parser, node: doc.expr):
annotation = self.eval_expr(node)
if callable(annotation):
annotation = annotation()
return annotation
23 changes: 17 additions & 6 deletions src/script/builder/tir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,25 @@ namespace script {
namespace builder {
namespace tir {

TVM_REGISTER_GLOBAL("script.builder.tir.Int8").set_body_typed(Int8);
TVM_REGISTER_GLOBAL("script.builder.tir.Int16").set_body_typed(Int16);
TVM_REGISTER_GLOBAL("script.builder.tir.Int32").set_body_typed(Int32);
TVM_REGISTER_GLOBAL("script.builder.tir.Int64").set_body_typed(Int64);
TVM_REGISTER_GLOBAL("script.builder.tir.UInt8").set_body_typed(UInt8);
TVM_REGISTER_GLOBAL("script.builder.tir.UInt16").set_body_typed(UInt16);
TVM_REGISTER_GLOBAL("script.builder.tir.UInt32").set_body_typed(UInt32);
TVM_REGISTER_GLOBAL("script.builder.tir.UInt64").set_body_typed(UInt64);
TVM_REGISTER_GLOBAL("script.builder.tir.Float8").set_body_typed(Float8);
TVM_REGISTER_GLOBAL("script.builder.tir.Float16").set_body_typed(Float16);
TVM_REGISTER_GLOBAL("script.builder.tir.Float32").set_body_typed(Float32);
TVM_REGISTER_GLOBAL("script.builder.tir.Float64").set_body_typed(Float64);
TVM_REGISTER_GLOBAL("script.builder.tir.Boolean").set_body_typed(Boolean);
TVM_REGISTER_GLOBAL("script.builder.tir.PrimType").set_body_typed(PrimType);
TVM_REGISTER_GLOBAL("script.builder.tir.Handle").set_body_typed(Handle);
TVM_REGISTER_GLOBAL("script.builder.tir.min").set_body_typed([](PrimExpr a, PrimExpr b) {
return tvm::min(a, b);
});
TVM_REGISTER_GLOBAL("script.builder.tir.max").set_body_typed([](PrimExpr a, PrimExpr b) {
return tvm::max(a, b);
});
TVM_REGISTER_GLOBAL("script.builder.tir.min")
.set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); });
TVM_REGISTER_GLOBAL("script.builder.tir.max")
.set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); });

} // namespace tir
} // namespace builder
Expand Down
Loading

0 comments on commit 342c85a

Please sign in to comment.