Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Script namespace changes #9115

Merged
merged 1 commit into from
Oct 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_python_package.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ pip3 install \
pytest-xdist \
requests \
scipy \
synr==0.4.0 \
synr==0.4.1 \
six \
tornado
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_sphinx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ pip3 install \
matplotlib \
sphinx \
sphinx_autodoc_annotation \
sphinx-gallery==0.4.0 \
sphinx-gallery==0.4.1 \
sphinx_rtd_theme
20 changes: 10 additions & 10 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,12 @@ class LinkedParam : public ObjectRef {
* \note We can define a Meta TIR function with symbolic shape:
*
* \code
* @tvm.script.tir
* def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
* A = tir.match_buffer(a, (m, n), "float32")
* B = tir.match_buffer(b, (m, n), "float32")
* @T.prim_func
* def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
* A = T.match_buffer(a, (m, n), "float32")
* B = T.match_buffer(b, (m, n), "float32")
*
* with tir.block([m, n], "") as [vi, vj]:
* with T.block([m, n], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* \endcode
*
Expand All @@ -214,12 +214,12 @@ class LinkedParam : public ObjectRef {
* \endcode
*
* \code {.language-id}
* @tvm.script.tir
* def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
* A = tir.match_buffer(a, (16, 16), "float32")
* B = tir.match_buffer(b, (16, 16), "float32")
* @T.prim_func
* def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
* A = T.match_buffer(a, (16, 16), "float32")
* B = T.match_buffer(b, (16, 16), "float32")
*
* with tir.block([16, 16], "") as [vi, vj]:
* with T.block([16, 16], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* \endcode
*/
Expand Down
20 changes: 10 additions & 10 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1068,17 +1068,17 @@ class MatchBufferRegion : public ObjectRef {
* \note Block's body is parameterized by iter vars.
* \code
*
* with tir.block([extent0, extent1, ...], name) as [v0, v1, ...]:
* tir.bind(v0, value0)
* tir.bind(v1, value1)
* with T.block([extent0, extent1, ...], name) as [v0, v1, ...]:
* T.bind(v0, value0)
* T.bind(v1, value1)
* ...
* tir.reads([buffer0[start:end, ...], ...])
* tir.writes([buffer1[start:end, ...], ...])
* tir.where(predicate)
* buffer2 = tir.alloc_buffer(shape, dtype)
* buffer3 = tir.match_buffer(source_buffer[start:end, ...])
* tir.attr({attr_key: attr_value, ...})
* with tir.init():
* T.reads([buffer0[start:end, ...], ...])
* T.writes([buffer1[start:end, ...], ...])
* T.where(predicate)
* buffer2 = T.alloc_buffer(shape, dtype)
* buffer3 = T.match_buffer(source_buffer[start:end, ...])
* T.attr({attr_key: attr_value, ...})
* with T.init():
* // init body
* // body
*
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ TVM_DLL Pass ConvertBlocksToOpaque();
* \code
*
* for i in range(0, 16):
* with tir.block([]):
* B = tir.alloc_buffer(16, 16)
* with T.block([]):
* B = T.alloc_buffer(16, 16)
* for j in range(0, 16):
* B[i, j] = A[i, j] + 1
* for j in range(0, 16):
Expand All @@ -404,8 +404,8 @@ TVM_DLL Pass ConvertBlocksToOpaque();
* \code
*
* for i in range(0, 16):
* with tir.block([]):
* B = tir.alloc_buffer(1, 16)
* with T.block([]):
* B = T.alloc_buffer(1, 16)
* for j in range(0, 16):
* B[0, j] = A[i, j] + 1
* for j in range(0, 16):
Expand Down
2 changes: 1 addition & 1 deletion python/gen_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@
("sphinx_autodoc_annotation", None),
("sphinx_gallery", None),
("sphinx_rtd_theme", None),
("synr", "==0.4.0"),
("synr", "==0.4.1"),
("tensorflow", None),
("tensorflow-estimator", None),
("tflite", None),
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,23 @@ def __str__(self):

def __repr__(self):
return self.astext()

def script(self, tir_prefix: str = "tir", show_meta: bool = False) -> str:
"""Print IRModule into TVMScript

Parameters
----------
tir_prefix : str
The tir namespace prefix

show_meta : bool
Whether to show meta information

Returns
-------
script : str
The TVM Script of the IRModule
"""
return tvm._ffi.get_global_func("script.AsTVMScript")(
self, tir_prefix, show_meta
) # type: ignore
4 changes: 3 additions & 1 deletion python/tvm/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@
# under the License.
"""TVM Script APIs of TVM Python Package, aimed to support TIR"""

from .parser import from_source, create_module, asscript, tir, module
from . import tir

from .parser import ir_module, from_source
48 changes: 24 additions & 24 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tvm.ir import Span
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
from .node import BufferSlice
from .tir.node import BufferSlice


class BlockInfo:
Expand All @@ -34,55 +34,55 @@ class BlockInfo:
----------
.. code-block:: python

@tvm.script.tir
def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (16, 16), "float32")
B = tir.match_buffer(b, (16, 16), "float32")
C = tir.match_buffer(a, (16, 16), "float32")
@T.prim_func
def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
B = T.match_buffer(b, (16, 16), "float32")
C = T.match_buffer(a, (16, 16), "float32")

for i, j, k in tir.grid(16, 16, 16):
with tir.block([16, 16, tir.reduce_axis(16)], "matmul") as [vi, vj, vk]:
tir.bind(vi, i)
tir.bind(vj, j)
tir.bind(vk, k) # iter_bindings = {vj: i, vj: j, vk: k}
for i, j, k in T.grid(16, 16, 16):
with T.block([16, 16, T.reduce_axis(16)], "matmul") as [vi, vj, vk]:
T.bind(vi, i)
T.bind(vj, j)
T.bind(vk, k) # iter_bindings = {vj: i, vj: j, vk: k}

tir.where(True) # predicate of the block_realize
T.where(True) # predicate of the block_realize

tir.reads(A[0:16, 0:16], B[0: 16, 0: 16]) # reads region of the block
tir.writes(C[0: 16, 0: 16]) # writes region of the block
tir.block_attr({"attr_key": "attr_value"}) # block annotations
T.reads(A[0:16, 0:16], B[0: 16, 0: 16]) # reads region of the block
T.writes(C[0: 16, 0: 16]) # writes region of the block
T.block_attr({"attr_key": "attr_value"}) # block annotations

# alloc_buffers inside the block
CC = tir.alloc_buffer((1, 1), dtype="float32")
CC = T.alloc_buffer((1, 1), dtype="float32")

# match_buffers of the block,
# which bind a sub-region of source buffer into a new buffer
D = tir.match_buffer(C[vi, vj], ())
D = T.match_buffer(C[vi, vj], ())

# init part of the block, executed when all reduce axes are the beginning value
with tir.init():
C[vi, vj] = tir.float32(0)
with T.init():
C[vi, vj] = T.float32(0)

# block body
CC[0, 0] = A[vi, vk] * B[vj, vk]
D[()] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0]
"""

alloc_buffers: List[Buffer] = []
"""List[Buffer]: list of tir.alloc_buffer statements in the block signature"""
"""List[Buffer]: list of T.alloc_buffer statements in the block signature"""
match_buffers: List[MatchBufferRegion] = []
"""List[MatchBufferRegion]: list of tir.match_buffer statements in the block signature"""
"""List[MatchBufferRegion]: list of T.match_buffer statements in the block signature"""
iter_bindings: Mapping[Var, PrimExpr] = {}
"""Mapping[Var, PrimExpr]: map of block iter var to its values"""
reads: Optional[List[BufferSlice]] = None
"""Optional[List[BufferSlice]]:
list of tir.reads statements in the block signature, None for not-visited"""
list of T.reads statements in the block signature, None for not-visited"""
writes: Optional[List[BufferSlice]] = None
"""Optional[List[BufferSlice]]:
list of tir.writes statements in the block signature, None for not-visited"""
list of T.writes statements in the block signature, None for not-visited"""
annotations: Optional[Mapping[str, Object]] = None
"""Optional[Mapping[str, Object]]:
list of tir.block_attr statements in the block signature, None for not-visited"""
list of T.block_attr statements in the block signature, None for not-visited"""
predicate: Optional[PrimExpr] = None
"""Optional[PrimExpr]: block realize predicate, None for not-visited"""
init: Optional[Stmt] = None
Expand Down
Loading