diff --git a/apps/bundle_deploy/bundle.c b/apps/bundle_deploy/bundle.c index 9083f7b5f48b..6018d40dd300 100644 --- a/apps/bundle_deploy/bundle.c +++ b/apps/bundle_deploy/bundle.c @@ -23,8 +23,8 @@ #include #include #include -#include #include +#include #ifdef ENABLE_TVM_ABORT_BACKTRACE #include "backtrace.h" @@ -64,8 +64,8 @@ TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data, dev.device_id = device_id; // declare pointers - TVM_CCALL(MemoryManagerCreate(&g_memory_manager, g_crt_memory, sizeof(g_crt_memory), - CRT_MEMORY_PAGE_SIZE_LOG2)); + TVM_CCALL(PageMemoryManagerCreate(&g_memory_manager, g_crt_memory, sizeof(g_crt_memory), + CRT_MEMORY_PAGE_SIZE_LOG2)); TVM_CCALL(TVMInitializeRuntime()); TVMPackedFunc pf; TVMArgs args = TVMArgs_Create(NULL, NULL, 0); diff --git a/apps/bundle_deploy/bundle_static.c b/apps/bundle_deploy/bundle_static.c index 62e63d6b4fe2..18a7b2bbb0ff 100644 --- a/apps/bundle_deploy/bundle_static.c +++ b/apps/bundle_deploy/bundle_static.c @@ -22,8 +22,8 @@ #include #include #include -#include #include +#include #include #ifdef ENABLE_TVM_PLATFORM_ABORT_BACKTRACE @@ -64,8 +64,8 @@ TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data, dev.device_id = device_id; // get pointers - TVM_CCALL(MemoryManagerCreate(&g_memory_manager, g_crt_memory, sizeof(g_crt_memory), - CRT_MEMORY_PAGE_SIZE_LOG2)); + TVM_CCALL(PageMemoryManagerCreate(&g_memory_manager, g_crt_memory, sizeof(g_crt_memory), + CRT_MEMORY_PAGE_SIZE_LOG2)); TVM_CCALL(TVMInitializeRuntime()); TVMPackedFunc pf; TVMArgs args = TVMArgs_Create(NULL, NULL, 0); diff --git a/cmake/modules/StandaloneCrt.cmake b/cmake/modules/StandaloneCrt.cmake index fe6baf81c3e5..620f7552cef6 100644 --- a/cmake/modules/StandaloneCrt.cmake +++ b/cmake/modules/StandaloneCrt.cmake @@ -44,6 +44,7 @@ if(USE_MICRO) "src/runtime/crt/include *.h -> include" "src/runtime/crt/common *.c -> src/runtime/crt/common" "src/runtime/crt/graph_executor *.c -> src/runtime/crt/graph_executor" + "src/runtime/crt/aot_executor *.c -> src/runtime/crt/aot_executor" "src/runtime/crt/graph_executor_module *.c -> src/runtime/crt/graph_executor_module" "src/runtime/crt/host crt_config.h -> template/host" "src/runtime/crt/host *.cc -> template/host" @@ -97,7 +98,7 @@ if(USE_MICRO) set(make_quiet ) endif(${VERBOSE}) - list(APPEND crt_libraries memory graph_executor utvm_rpc_server utvm_rpc_common common) # NOTE: listed in link order. + list(APPEND crt_libraries memory graph_executor aot_executor utvm_rpc_server utvm_rpc_common common) # NOTE: listed in link order. foreach(crt_lib_name IN LISTS crt_libraries) list(APPEND crt_library_paths "host_standalone_crt/lib${crt_lib_name}.a") endforeach() diff --git a/include/tvm/runtime/crt/memory.h b/include/tvm/runtime/crt/page_allocator.h similarity index 87% rename from include/tvm/runtime/crt/memory.h rename to include/tvm/runtime/crt/page_allocator.h index c830116528e0..7a5de169c72e 100644 --- a/include/tvm/runtime/crt/memory.h +++ b/include/tvm/runtime/crt/page_allocator.h @@ -18,12 +18,12 @@ */ /*! - * \file tvm/runtime/crt/memory.h + * \file tvm/runtime/crt/page_allocator.h * \brief An implementation of a dynamic memory allocator for microcontrollers. */ -#ifndef TVM_RUNTIME_CRT_MEMORY_H_ -#define TVM_RUNTIME_CRT_MEMORY_H_ +#ifndef TVM_RUNTIME_CRT_PAGE_ALLOCATOR_H_ +#define TVM_RUNTIME_CRT_PAGE_ALLOCATOR_H_ #ifdef __cplusplus extern "C" { @@ -72,11 +72,11 @@ struct MemoryManagerInterface { * \param page_size_bytes_log2 log2 of the page size, in bytes. 
* \return kTvmErrorNoError on success. */ -tvm_crt_error_t MemoryManagerCreate(MemoryManagerInterface** manager, uint8_t* memory_pool, - size_t memory_pool_size_bytes, size_t page_size_bytes_log2); +tvm_crt_error_t PageMemoryManagerCreate(MemoryManagerInterface** manager, uint8_t* memory_pool, + size_t memory_pool_size_bytes, size_t page_size_bytes_log2); #ifdef __cplusplus } // extern "C" #endif -#endif // TVM_RUNTIME_CRT_MEMORY_H_ +#endif // TVM_RUNTIME_CRT_PAGE_ALLOCATOR_H_ diff --git a/include/tvm/runtime/crt/stack_allocator.h b/include/tvm/runtime/crt/stack_allocator.h new file mode 100644 index 000000000000..daa403cb2764 --- /dev/null +++ b/include/tvm/runtime/crt/stack_allocator.h @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// LINT_C_FILE +#ifndef TVM_RUNTIME_CRT_STACK_ALLOCATOR_H_ +#define TVM_RUNTIME_CRT_STACK_ALLOCATOR_H_ +#include +#include + +#include "crt_config.h" +#include "error_codes.h" + +#define STACK_ALLOCATOR_TAG 0xabcd1234 +#define STACK_ALLOCATOR_TAG_SIZE_BYTES 4 + +/*! Memory alignment for allocator */ + +#ifndef TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES +#define TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES 16 +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct { + uint8_t* next_alloc; // Pointer to the next block of TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES + uint8_t* workspace; // Pointer to start of the workspace + size_t workspace_size; // Total number of bytes in the workspace +} tvm_workspace_t; + +tvm_crt_error_t StackMemoryManager_Init(tvm_workspace_t* tvm_runtime_workspace, + uint8_t* g_aot_memory, size_t workspace_size); + +tvm_crt_error_t StackMemoryManager_Allocate(tvm_workspace_t* tvm_runtime_workspace, int32_t nbytes, + void**); + +tvm_crt_error_t StackMemoryManager_Free(tvm_workspace_t* tvm_runtime_workspace, void* ptr); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TVM_RUNTIME_CRT_STACK_ALLOCATOR_H_ diff --git a/include/tvm/runtime/executor_info.h b/include/tvm/runtime/executor_info.h new file mode 100644 index 000000000000..5b3572120c9a --- /dev/null +++ b/include/tvm/runtime/executor_info.h @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
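+/*
+ * A minimal usage sketch of the stack allocator declared in
+ * tvm/runtime/crt/stack_allocator.h above (the buffer size and names are
+ * illustrative, error handling elided). Blocks are expected to be freed in
+ * reverse order of allocation; defining TVM_CRT_STACK_ALLOCATOR_ENABLE_FIFO_CHECK
+ * in crt_config.h enforces that ordering at runtime:
+ *
+ *   static uint8_t g_workspace[4096];
+ *   tvm_workspace_t ws;
+ *   void* block = NULL;
+ *
+ *   StackMemoryManager_Init(&ws, g_workspace, sizeof(g_workspace));
+ *   StackMemoryManager_Allocate(&ws, 256, &block);
+ *   // ... use block ...
+ *   StackMemoryManager_Free(&ws, block);
+ */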
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file executor_info.h
+ * \brief Executor information
+ */
+#ifndef TVM_RUNTIME_EXECUTOR_INFO_H_
+#define TVM_RUNTIME_EXECUTOR_INFO_H_
+
+namespace tvm {
+namespace runtime {
+
+/*! \brief Value used to indicate the graph executor. */
+static constexpr const char* kTvmExecutorGraph = "graph";
+
+/*! \brief Value used to indicate the aot executor. */
+static constexpr const char* kTvmExecutorAot = "aot";
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_EXECUTOR_INFO_H_
diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index 04a5cf8bf25d..689fe6fa53fc 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -230,6 +230,8 @@ constexpr const char* tvm_module_main = "__tvm_main__";
 constexpr const char* tvm_param_prefix = "__tvm_param__";
 /*! \brief A PackedFunc that looks up linked parameters by storage_id. */
 constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
+/*! \brief The main AOT executor function */
+constexpr const char* tvm_run_func_prefix = "tvm__run_func";
 }  // namespace symbol
 
 // implementations of inline functions.
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index d8248d4e1a87..aab5d662d49c 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -346,6 +346,18 @@ TVM_DLL const Op& tvm_stack_make_array();
  */
 TVM_DLL const Op& tvm_call_packed();
 
+/*!
+ * \brief See pseudo code
+ *
+ *  return_type tvm_call_cpacked(fname, TVMValue* args) {
+ *     int ret_code;
+ *     TVMValue ret_value;
+ *     (*fname)(args, type_code_of(args), len(args), &ret_value, &ret_code);
+ *     return cast(return_type, ret_value.v_return_type);
+ *  }
+ */
+TVM_DLL const Op& tvm_call_cpacked();
+
 /*!
  * \brief See pesudo code
  *
@@ -392,6 +404,21 @@ TVM_DLL const Op& tvm_thread_context();
  */
 TVM_DLL const Op& tvm_call_packed_lowered();
 
+/*!
+ * \brief Lowered version of call c-packed, the space of value and
+ *  type codes are explicitly allocated.
+ *
+ *  int tvm_call_cpacked_lowered(fname,
+ *                               TVMValue* value_stack,
+ *                               int* tcode_stack,
+ *                               int begin,
+ *                               int end) {
+ *     fname(TVMArgs(value_stack[begin:end], tcode_stack[begin:end]),
+ *           TVMRetValue(value_stack + end, tcode_stack + end));
+ *  }
+ */
+TVM_DLL const Op& tvm_call_cpacked_lowered();
+
 /*!
  * \brief Lowered version of trace intrinsic, the space of value and
  *  type codes are explicitly allocated.
The return value is the diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index a26a47c788fe..e48125f0f619 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -52,7 +52,7 @@ import tvm.contrib.cc from tvm import relay from tvm.contrib import utils -from tvm.relay.backend.graph_executor_factory import GraphExecutorFactoryModule +from tvm.relay.backend.executor_factory import GraphExecutorFactoryModule from .common import TVMCException @@ -220,7 +220,7 @@ def export_package( self.lib_path = path_lib with open(temp.relpath(graph_name), "w") as graph_file: - graph_file.write(executor_factory.get_json()) + graph_file.write(executor_factory.get_graph_json()) with open(temp.relpath(param_name), "wb") as params_file: params_file.write(relay.save_param_dict(executor_factory.get_params())) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 6768e03f4473..4fd85ea38d98 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -24,7 +24,7 @@ import tarfile from ..contrib import utils -from ..relay.backend import graph_executor_factory +from ..relay.backend import executor_factory from ..relay import param_dict @@ -117,7 +117,7 @@ def _build_memory_map(graph_json): return memory_map -def export_model_library_format(mod: graph_executor_factory.GraphExecutorFactoryModule, file_name): +def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, file_name): """Export the build artifact in Model Library Format. This function creates a .tar archive containing the build artifacts in a standardized @@ -126,20 +126,25 @@ def export_model_library_format(mod: graph_executor_factory.GraphExecutorFactory Parameters ---------- - mod : tvm.relay.backend.graph_executor_factory.GraphExecutorFactoryModule + mod : tvm.relay.backend.executor_factory.ExecutorFactoryModule The return value of tvm.relay.build, which will be exported into Model Library Format. file_name : str Path to the .tar archive to generate. 
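+    Examples
+    --------
+    A minimal sketch (the module, target, and paths are placeholders, not part of
+    this change) of exporting the factory returned by tvm.relay.build:
+
+    .. code-block:: python
+
+        factory = tvm.relay.build(ir_mod, target="c", params=params)
+        export_model_library_format(factory, "./build/module.tar")
+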
""" tempdir = utils.tempdir() + is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule) + memory_map = [] if is_aot else _build_memory_map(mod.get_executor_config()) + runtime = ["aot"] if is_aot else ["graph"] + metadata = { "version": 1, "model_name": mod.libmod_name, "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), - "memory": _build_memory_map(mod.graph_json), + "memory": memory_map, "target": {int(k): str(v) for k, v in mod.target.items()}, - "runtimes": ["graph"], + "runtimes": runtime, } + with open(tempdir.relpath("metadata.json"), "w") as json_f: json.dump(metadata, json_f, indent=2, sort_keys=True) @@ -156,10 +161,11 @@ def export_model_library_format(mod: graph_executor_factory.GraphExecutorFactory with open(tempdir.relpath("relay.txt"), "w") as f: f.write(str(mod.ir_mod)) - graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph")) - os.makedirs(graph_config_dir_path) - with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f: - f.write(mod.graph_json) + if not is_aot: + graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph")) + os.makedirs(graph_config_dir_path) + with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f: + f.write(mod.get_executor_config()) with tarfile.open(file_name, "w") as tar_f: diff --git a/python/tvm/relay/backend/graph_executor_factory.py b/python/tvm/relay/backend/executor_factory.py similarity index 62% rename from python/tvm/relay/backend/graph_executor_factory.py rename to python/tvm/relay/backend/executor_factory.py index d6959d22e5c8..f81d8f9f1c15 100644 --- a/python/tvm/relay/backend/graph_executor_factory.py +++ b/python/tvm/relay/backend/executor_factory.py @@ -14,21 +14,100 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Graph executor factory.""" +"""Executor factory modules.""" +from abc import abstractmethod import warnings + from ..._ffi.base import string_types from ..._ffi.registry import get_global_func from ...runtime import ndarray -class GraphExecutorFactoryModule: +class ExecutorFactoryModule: + """Common interface for executor factory modules + This class describes the common API of different + factory modules + """ + + @abstractmethod + def get_executor_config(self): + """ Return the internal configuration the executor uses to execute the network """ + raise NotImplementedError + + @abstractmethod + def get_params(self): + """Return the compiled parameters.""" + raise NotImplementedError + + @abstractmethod + def get_lib(self): + """ Return the generated library""" + raise NotImplementedError + + def __getitem__(self, item): + return self.module.__getitem__(item) + + def __iter__(self): + warnings.warn( + "legacy graph executor behavior of producing json / lib / params will be " + "removed in the next release." + " Please see documents of tvm.contrib.graph_executor.GraphModule for the " + " new recommended usage.", + DeprecationWarning, + 2, + ) + return self + + def __next__(self): + if self.iter_cnt > 2: + raise StopIteration + + objs = [self.get_executor_config(), self.lib, self.params] + obj = objs[self.iter_cnt] + self.iter_cnt += 1 + return obj + + +class AOTExecutorFactoryModule(ExecutorFactoryModule): + """AOT executor factory module. + + Attributes + ---------- + target : tvm.Target + The Target used to build this module. 
+ libmod : tvm.Module + The module of the corresponding function + libmod_name: str + The name of module + params : dict of str to NDArray + The parameters of module + """ + + def __init__(self, ir_mod, target, libmod, libmod_name, params): + self.ir_mod = ir_mod + self.target = target + self.lib = libmod + self.libmod_name = libmod_name + self.params = params + self.iter_cnt = 0 + + def get_params(self): + return self.params + + def get_executor_config(self): + return None + + def get_lib(self): + return self.lib + + +class GraphExecutorFactoryModule(ExecutorFactoryModule): """Graph executor factory module. This is a module of graph executor factory - Parameters + Attributes ---------- - graph_json_str : str - The graph to be deployed in json format output by graph compiler. + graph_json_str : the json graph to be deployed in json format output by graph compiler. The graph can contain operator(tvm_op) that points to the name of PackedFunc in the libmod. target : tvm.Target @@ -48,6 +127,7 @@ def __init__(self, ir_mod, target, graph_json_str, libmod, libmod_name, params): for k, v in params.items(): args.append(k) args.append(ndarray.array(v)) + self.ir_mod = ir_mod self.target = target self.module = fcreate(graph_json_str, libmod, libmod_name, *args) @@ -60,37 +140,14 @@ def __init__(self, ir_mod, target, graph_json_str, libmod, libmod_name, params): def export_library(self, file_name, fcompile=None, addons=None, **kwargs): return self.module.export_library(file_name, fcompile, addons, **kwargs) - # Sometimes we want to get params explicitly. - # For example, we want to save its params value to - # an independent file. def get_params(self): return self.params - def get_json(self): + def get_graph_json(self): + return self.graph_json + + def get_executor_config(self): return self.graph_json def get_lib(self): return self.lib - - def __getitem__(self, item): - return self.module.__getitem__(item) - - def __iter__(self): - warnings.warn( - "legacy graph executor behavior of producing json / lib / params will be " - "removed in the next release." - " Please see documents of tvm.contrib.graph_executor.GraphModule for the " - " new recommended usage.", - DeprecationWarning, - 2, - ) - return self - - def __next__(self): - if self.iter_cnt > 2: - raise StopIteration - - objs = [self.graph_json, self.lib, self.params] - obj = objs[self.iter_cnt] - self.iter_cnt += 1 - return obj diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 6eb684e570d9..2d8c8207c930 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -34,7 +34,7 @@ from . import expr as _expr from . import function as _function from .transform import InferType -from .backend import graph_executor_factory as _graph_executor_factory +from .backend import executor_factory as _executor_factory from .backend import interpreter as _interpreter from .backend.vm import VMExecutor @@ -84,7 +84,7 @@ def __init__(self): self._set_params_func = self.mod["set_params"] self._get_params_func = self.mod["get_params"] - def build(self, mod, target=None, target_host=None, params=None): + def build(self, mod, target=None, target_host=None, params=None, executor="graph"): """ Parameters ---------- @@ -109,6 +109,11 @@ def build(self, mod, target=None, target_host=None, params=None): Input parameters to the graph that do not change during inference time. Used for constant folding. 
+        executor : Optional[str]
+            The type of executor to be used in order to run the model:
+            - If "graph" is specified, then the graph_executor will be used
+            - If "aot" is specified, then the aot_executor will be used
+
         Returns
         -------
         graph_json : str
@@ -139,15 +144,15 @@
         old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
         autotvm.GLOBAL_SCOPE.silent = use_auto_scheduler
 
-        self._build(mod, target, target_host)
+        self._build(mod, target, target_host, executor)
         autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent
 
         # Get artifacts
-        graph_json = self.get_json()
         mod = self.get_module()
         params = self.get_params()
+        executor_config = self.get_graph_json() if executor == "graph" else None
 
-        return graph_json, mod, params
+        return executor_config, mod, params
 
     def optimize(self, mod, target=None, params=None):
         """
@@ -187,7 +192,7 @@ def optimize(self, mod, target=None, params=None):
     def _set_params(self, params):
         self._set_params_func(_convert_param_map(params))
 
-    def get_json(self):
+    def get_graph_json(self):
         """Return the json file of the built program."""
         return self._get_graph_json()
@@ -219,6 +224,33 @@ def _build_module_no_factory(mod, target=None, target_host=None, params=None, mo
     return build(mod, target, params=params, mod_name=mod_name).module
 
 
+def get_executor_from_target(target, target_host):
+    """Helper function to extract the executor parameter from the target
+
+    Parameters
+    ----------
+    target : dict
+        Dictionary of targets, keyed by device type, used for heterogeneous compilation
+
+    target_host : tvm.Target
+        Host compilation target
+
+    Returns
+    -------
+    executor : str
+        A string representing the executor type
+    """
+
+    # Default executor is graph
+    executor = "graph"
+    cpu_device_type = 1
+    if target_host:
+        executor = target_host.attrs.get("executor", "graph")
+    else:
+        for device_type in target:
+            if device_type == cpu_device_type:
+                executor = target[device_type].attrs.get("executor", "graph")
+    return executor
+
+
 def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"):
     # fmt: off
     # pylint: disable=line-too-long
@@ -251,7 +283,7 @@
 
     Returns
     -------
-    factory_module : tvm.relay.backend.graph_executor_factory.GraphExecutorFactoryModule
-        The runtime factory for the TVM graph executor.
+    factory_module : tvm.relay.backend.executor_factory.ExecutorFactoryModule
+        The runtime factory for the TVM executor.
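+
+    Examples
+    --------
+    A short sketch (the model, params, and target string are assumptions, not
+    part of this change) of requesting the AOT executor through the target's
+    "executor" attribute, which get_executor_from_target() reads back out:
+
+    .. code-block:: python
+
+        target = tvm.target.Target("c -executor=aot")
+        with tvm.transform.PassContext(opt_level=3):
+            factory = tvm.relay.build(ir_mod, target=target, params=params)
+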
""" # pylint: enable=line-too-long @@ -278,6 +310,9 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" target, target_host, target_is_dict_key=False ) + # Retrieve the executor from the target + executor = get_executor_from_target(target, target_host) + # If current dispatch context is fallback context (the default root context), # then load pre-tuned parameters from TopHub if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): @@ -287,10 +322,21 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" with tophub_context: bld_mod = BuildModule() - graph_json, runtime_mod, params = bld_mod.build(mod=ir_mod, target=target, params=params) - executor_factory = _graph_executor_factory.GraphExecutorFactoryModule( - ir_mod, target, graph_json, runtime_mod, mod_name, params + executor_config, runtime_mod, params = bld_mod.build( + mod=ir_mod, target=target, params=params, executor=executor ) + + if executor == "aot": + executor_factory = _executor_factory.AOTExecutorFactoryModule( + ir_mod, target, runtime_mod, mod_name, params + ) + elif executor == "graph": + executor_factory = _executor_factory.GraphExecutorFactoryModule( + ir_mod, target, executor_config, runtime_mod, mod_name, params + ) + else: + assert False, "Executor " + executor + " not supported" + return executor_factory diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc new file mode 100644 index 000000000000..1939e05e2075 --- /dev/null +++ b/src/relay/backend/aot_executor_codegen.cc @@ -0,0 +1,671 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file relay/backend/aot_executor_codegen.cc
+ * \brief AOT executor codegen
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+#include "compile_engine.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+using IntegerArray = Array<Integer>;
+using TargetsMap = std::unordered_map<int, Target>;
+
+class AotReturnSidVisitor : public ExprVisitor {
+ public:
+  explicit AotReturnSidVisitor(Map<Expr, Array<IntegerArray>> storage_device_map)
+      : storage_device_map_{storage_device_map}, return_sid_{-1} {}
+
+  IntegerArray FindReturnSid(Function func) {
+    VisitExpr(func->body);
+    return return_sid_;
+  }
+
+ protected:
+  void AssignReturnSid(Expr e) {
+    auto iter = storage_device_map_.find(e);
+    if (iter != storage_device_map_.end()) {
+      return_sid_ = (*iter).second[0];
+    }
+  }
+
+  void VisitExpr_(const ConstantNode* cn) override {
+    ExprVisitor::VisitExpr_(cn);
+    AssignReturnSid(GetRef<Expr>(cn));
+  }
+
+  void VisitExpr_(const VarNode* vn) override {
+    ExprVisitor::VisitExpr_(vn);
+    AssignReturnSid(GetRef<Expr>(vn));
+  }
+
+  void VisitExpr_(const CallNode* cn) override {
+    ExprVisitor::VisitExpr_(cn);
+    AssignReturnSid(GetRef<Expr>(cn));
+  }
+
+  void VisitExpr_(const LetNode* op) override { VisitExpr(op->body); }
+
+  void VisitExpr_(const TupleNode* tn) override {
+    ExprVisitor::VisitExpr_(tn);
+    AssignReturnSid(GetRef<Expr>(tn));
+  }
+
+ private:
+  Map<Expr, Array<IntegerArray>> storage_device_map_;
+  IntegerArray return_sid_;
+};
+
+/*! \brief Code generator for AOT executor */
+class AOTExecutorCodegen : public ExprVisitor {
+ protected:
+  /*!
+   * \brief Utility function to allocate a DLTensor or TVMValue
+   * \param type the type of allocation
+   * \param num the number of variables to allocate on the stack
+   * \return PrimExpr representing the allocated object
+   */
+  PrimExpr StackAlloca(std::string type, size_t num) {
+    Array<PrimExpr> args = {tir::StringImm(type), ConstInt32(num)};
+    return tir::Call(DataType::Handle(), tir::builtin::tvm_stack_alloca(), args);
+  }
+
+  /*!
+   * \brief Utility function to convert a concrete integer to a PrimExpr.
+   * \param num the number to convert
+   * \return PrimExpr representing num
+   */
+  inline PrimExpr ConstInt32(size_t num) {
+    ICHECK_LE(num, std::numeric_limits<int>::max());
+    return tir::make_const(DataType::Int(32), static_cast<int>(num));
+  }
+
+  /*!
+   * \brief Return a vector of variables that represents the sids for the given Relay Expr
+   */
+  std::vector<tir::Var> PackSid(Expr expr) {
+    Array<IntegerArray> sids = storage_device_map_[expr];
+    std::vector<tir::Var> sid_vars;
+
+    // Note that an expression can have multiple sids associated with it,
+    // e.g., returning multiple values from a function
+    for (const auto& sid : sids[0]) {
+      // Determine if an sid is an output buffer
+      int sid_int = static_cast<int>((sid.as<IntImmNode>())->value);
+      auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid_int);
+      if (output_iter != return_sid_.end()) {
+        int output_index = std::distance(return_sid_.begin(), output_iter);
+        sid_vars.push_back(main_signature_[input_vars_.size() + output_index]);
+        continue;
+      }
+      // Pack the sid inside the TVMValue
+      auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle());
+      auto sid_value = sids_table_[sid];
+      tvm::PrimExpr set_tensor =
+          tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                         {sid_array, 0, tir::builtin::kArrData, sid_value});
+      stmts_.push_back(tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor)));
+      sid_vars.push_back(sid_array);
+    }
+    return sid_vars;
+  }
+
+  /*!
+   * \brief Utility function to return a parameter associated with an expression
+   * \param expr Relay Expression associated with the parameter
+   * \return Variable that represents the DLTensor associated with the parameters
+   */
+  tir::Var PackParam(Expr expr) {
+    // TODO(giuseros): Using call_extern to call into lookup_linked_param. This is because the
+    // builtin::ret is not supported yet in the c target. Once return is supported we can use
+    // tvm_call_packed_lowered().
+    int param_sid = param_storage_ids_[params_by_expr_[expr]];
+    auto lookup_linked_param_fn = tir::StringImm(::tvm::runtime::symbol::tvm_lookup_linked_param);
+    auto param_array = te::Var(MakeString("param_", param_sid, "_array"), DataType::Handle());
+
+    // Compose the lookup_call using a local stack
+    Array<tir::Stmt> lookup_call;
+    auto param_var = te::Var(MakeString("param_", param_sid, "_value"), DataType::Handle());
+    auto ret_var = te::Var("ret_value", DataType::Handle());
+    auto ret_code = te::Var("ret_code", DataType::Handle());
+
+    lookup_call.push_back(tir::Evaluate(
+        tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                       {param_var, 0, tir::builtin::kTVMValueContent, ConstInt32(param_sid)})));
+    lookup_call.push_back(tir::Evaluate(
+        tvm::tir::Call(DataType::Handle(), tir::builtin::call_extern(),
+                       {lookup_linked_param_fn, param_var, 0, 0, ret_var, ret_code, 0})));
+    auto ret_var_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
+                                         {ret_var, 0, tir::builtin::kTVMValueContent});
+
+    // Set the param to the value returned by lookup_call
+    tvm::PrimExpr set_param_array =
+        tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                       {param_array, 0, tir::builtin::kArrData, ret_var_handle});
+    lookup_call.push_back(tir::Evaluate(set_param_array));
+
+    tir::Stmt lookup_body = tir::SeqStmt(lookup_call);
+
+    // Allocate the DLTensors on the stack
+    lookup_body = tir::LetStmt(param_var, StackAlloca("arg_value", 1), lookup_body);
+    lookup_body = tir::LetStmt(ret_var, StackAlloca("arg_value", 1), lookup_body);
+    lookup_body = tir::LetStmt(ret_code, StackAlloca("arg_value", 1), lookup_body);
+    lookup_body = tir::LetStmt(param_array, StackAlloca("arg_value", 1), lookup_body);
+    stmts_.push_back(lookup_body);
+    return param_array;
+  }
+
+  /*!
+   * \brief Given an expression return the variable(s) associated with that expression
+   */
+  std::vector<tir::Var> FindExpr(Expr arg) {
+    auto input_iter = std::find(input_vars_.begin(), input_vars_.end(), arg);
+    if (input_iter != input_vars_.end()) {
+      // Input variable
+      int main_index = std::distance(input_vars_.begin(), input_iter);
+      return {main_signature_[main_index]};
+    } else if (params_by_expr_.find(arg) != params_by_expr_.end()) {
+      // Parameter of the network
+      return {PackParam(arg)};
+    } else {
+      // Storage identifier (i.e., intermediate memory)
+      return PackSid(arg);
+    }
+  }
+
+  /*!
+   * \brief Call a function with a given name
+   */
+  void CreateFuncCall(Call call, std::string func_name) {
+    tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
+    std::vector<tir::Stmt> create_func_call_stmts;
+
+    // Pack the inputs
+    for (Expr arg : call->args) {
+      auto var_arg = FindExpr(arg);
+      args.push_back(var_arg[0]);
+    }
+
+    auto ret_expr = Downcast<Expr>(call);
+
+    // Pack the return value(s); a call node can produce multiple outputs
+    for (const auto& var : PackSid(ret_expr)) {
+      args.push_back(var);
+    }
+
+    // Use tvm_call_cpacked to execute the function
+    create_func_call_stmts.push_back(tir::Evaluate(
+        tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args)));
+    tir::Stmt body = tir::SeqStmt(create_func_call_stmts);
+    stmts_.push_back(body);
+  }
+
+  /*!
+   * \brief Copy a variable to the output. This function is mainly used in edge cases
+   * when we want to return an input or a parameter.
+   * TODO(giuseros): we should try to avoid unnecessary copies to the output, e.g., in a
+   * copy-on-write fashion.
+   */
+  void CopyToOutput(te::Var out, te::Var in, size_t size) {
+    auto retval_get = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
+                                     {in, 0, tir::builtin::kArrData});
+
+    // Define intermediate DLTensor to load/store the data
+    auto tmp0 = te::Var("tmp0", DataType::Handle());
+    auto tmp1 = te::Var("tmp1", DataType::Handle());
+    te::Var loop_idx("i", DataType::Int(32));
+    auto retval_i = tir::Load(DataType::UInt(8), tmp0, loop_idx, tir::const_true());
+    auto tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
+                                  {out, 0, tir::builtin::kArrData});
+
+    // Copy the variable from the input to the output
+    tir::Stmt copy = tir::For(
+        loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial,
+        tir::Store(tmp1, tir::Let(tmp0, retval_get, retval_i), loop_idx, tir::const_true()));
+    stmts_.push_back(tir::LetStmt(tmp1, tostore, copy));
+  }
+
+  /*!
+   * \brief Utility function to string together different arguments
+   */
+  template <typename... Args>
+  std::string MakeString(Args const&... args) {
+    std::ostringstream ss;
+    using List = int[];
+    (void)List{0, ((void)(ss << args), 0)...};
+
+    return ss.str();
+  }
+
+  void VisitExpr_(const CallNode* op) override {
+    // Descend the call tree
+    for (auto arg : op->args) {
+      VisitExpr(arg);
+    }
+
+    Expr expr = GetRef<Expr>(op);
+    Function func;
+    if (op->op.as<OpNode>()) {
+      LOG(FATAL) << "Operators should be transformed away; try applying "
+                 << "the fuse_ops transformation to the expression.";
+    } else if (op->op.as<GlobalVarNode>()) {
+      LOG(FATAL) << "Not implemented";
+    } else if (op->op.as<FunctionNode>()) {
+      func = GetRef<Function>(op->op.as<FunctionNode>());
+    } else {
+      LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
+    }
+    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
+      LOG(FATAL) << "TVM only supports calls to primitive functions "
+                 << "(i.e., functions composed of fusable operator invocations)";
+    }
+
+    auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
+    auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
+    Target target;
+    // Handle external function
+    if (func->GetAttr<String>(attr::kCompiler).defined()) {
+      target = Target("ext_dev");
+      CCacheKey key = (*pf0)(func, target);
+      CachedFunc ext_func = (*pf1)(compile_engine_, key);
+      ICHECK(ext_func.defined()) << "External function is not defined.";
+      UpdateConstants(func, &params_);
+
+      // Generate the TIR function call
+      CreateFuncCall(GetRef<Call>(op), ext_func->func_name);
+      return;
+    }
+
+    ICHECK_GE(storage_device_map_.count(expr), 0);
+    auto& device_type = storage_device_map_[expr][1];
+    auto call_dev_type = device_type[0]->value;
+    // Normal Relay Function
+    if (targets_.size() == 1) {
+      // homogeneous execution.
+      const auto& it = targets_.begin();
+      target = (*it).second;
+    } else {
+      // heterogeneous execution.
+      std::string call_dev_name;
+      if (call_dev_type == 0) {
+        call_dev_name = "llvm";
+      } else {
+        call_dev_name = runtime::DeviceName(call_dev_type);
+      }
+      if (targets_.count(call_dev_type) == 0) {
+        LOG(FATAL) << "No target is provided for device " << call_dev_name;
+      }
+      target = targets_[call_dev_type];
+    }
+    CCacheKey key = (*pf0)(func, target);
+    CachedFunc lowered_func = (*pf1)(compile_engine_, key);
+    if (!lowered_funcs_.count(target->str())) {
+      lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
+    }
+    lowered_funcs_[target->str()]->Update(lowered_func->funcs);
+
+    // Generate the TIR function call
+    CreateFuncCall(GetRef<Call>(op), lowered_func->func_name);
+  }
+
+  void VisitExpr_(const VarNode* op) override {
+    Expr expr = GetRef<Expr>(op);
+
+    // If the Var node is an output node we need to copy the content of the variable to the output.
+    // It's safe to check the SID here because Var StorageTokens are never reallocated
+    Array<IntegerArray> sids = storage_device_map_[expr];
+
+    auto output_iter = std::find(return_sid_.begin(), return_sid_.end(),
+                                 static_cast<int>((sids[0][0].as<IntImmNode>())->value));
+    if (output_iter != return_sid_.end()) {
+      int output_index = std::distance(return_sid_.begin(), output_iter);
+      auto var_expr = FindExpr(expr);
+      CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0], sids[2][0]);
+    }
+  }
+
+  void VisitExpr_(const ConstantNode* op) override {
+    Expr expr = GetRef<Expr>(op);
+    size_t index = params_.size();
+    std::string name = "p" + std::to_string(index);
+
+    param_storage_ids_[name] = storage_device_map_[expr][0][0]->value;
+    params_[name] = op->data;
+    params_by_expr_.Set(expr, name);
+
+    // If the Constant node is an output node we need to copy the content of the parameter to the
+    // output. A Constant node can only produce a single output
+    Array<IntegerArray> sids = storage_device_map_[expr];
+    auto output_iter = std::find(return_sid_.begin(), return_sid_.end(),
+                                 static_cast<int>((sids[0][0].as<IntImmNode>())->value));
+    if (output_iter != return_sid_.end()) {
+      int output_index = std::distance(return_sid_.begin(), output_iter);
+      CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr), sids[2][0]);
+    }
+  }
+
+  void VisitExpr_(const TupleNode* op) override {
+    for (auto field : op->fields) {
+      VisitExpr(field);
+    }
+  }
+
+  void VisitExpr_(const LetNode* op) override {
+    // TODO(giuseros): support Let nodes in AOT
+    CHECK(false) << "Let not yet implemented in AOT";
+  }
+  void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); }
+  void VisitExpr_(const OpNode* op) override {
+    throw std::runtime_error("cannot compile op in non-eta expanded form");
+  }
+  void VisitExpr_(const GlobalVarNode* op) override { throw std::runtime_error(""); }
+  void VisitExpr_(const IfNode* op) override { throw std::invalid_argument("if not supported"); }
+  void VisitExpr_(const FunctionNode* op) override {
+    ICHECK(op->GetAttr<String>(attr::kCompiler).defined())
+        << "FunctionNode only supported by custom codegen";
+  }
+  void VisitExpr_(const RefCreateNode* op) override {
+    throw std::invalid_argument("reference not supported");
+  }
+  void VisitExpr_(const RefReadNode* op) override {
+    throw std::invalid_argument("reference not supported");
+  }
+  void VisitExpr_(const RefWriteNode* op) override {
+    throw std::invalid_argument("reference not supported");
+  }
+  void VisitExpr_(const ConstructorNode* op) override {
+    throw std::invalid_argument("ADT constructor case not yet implemented");
+  }
+  void VisitExpr_(const MatchNode* op) override {
+    throw std::invalid_argument("match case not yet implemented");
+  }
+
+  // Create the main PrimFunc to execute the graph
+  tir::PrimFunc CreateMainFunc(unsigned int relay_params) {
+    tir::Stmt body = tir::SeqStmt(stmts_);
+
+    // Allocate the sids
+    std::unordered_map<int, bool> allocated;
+
+    for (auto kv : storage_device_map_) {
+      // Only allocate sids that are needed
+      const bool is_input =
+          (std::find(input_vars_.begin(), input_vars_.end(), kv.first) != input_vars_.end());
+      const bool is_param = (params_by_expr_.find(kv.first) != params_by_expr_.end());
+      if (is_input || is_param) {
+        continue;
+      }
+
+      for (unsigned int i = 0; i < kv.second[0].size(); i++) {
+        int size = kv.second[2][i];
+        int sid = static_cast<int>((kv.second[0][i].as<IntImmNode>())->value);
+
+        if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) {
+          continue;
+        }
+
+        // TODO(giuseros): we should allocate this once outside the PrimFunc
+        // so we don't pay the price of allocation for every inference
+        if (!allocated[sid]) {
+          body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size}, tir::const_true(), body);
+        }
+        allocated[sid] = true;
+      }
+    }
+
+    // Define the attributes
+    body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_type, 1, body);
+    body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_id, 0, body);
+
+    // Define the PrimFunc attributes
+    Map<String, ObjectRef> dict_attrs;
+    dict_attrs.Set("global_symbol", runtime::String(runtime::symbol::tvm_run_func_prefix));
+
+    // Make the PrimFunc
+    return tir::PrimFunc(main_signature_, body, VoidType(), Map<tir::Var, tir::Buffer>(),
+                         DictAttrs(dict_attrs));
+  }
+
+ protected:
+  /*! \brief mod */
+  runtime::Module* mod_;
+  /*! \brief list of input expressions (i.e., variables passed by the user) */
+  std::vector<Expr> input_vars_;
+  /*! \brief input and output variables belonging to the main function signature */
+  Array<tir::Var> main_signature_;
+  /*! \brief target device */
+  TargetsMap targets_;
+  /*! \brief target host */
+  Target target_host_;
+
+  /*!
+   * \brief parameters (i.e. ConstantNodes found in the graph).
+   * These are taken as inputs to the GraphRuntime.
+   * Maps param name to a pair of storage_id and NDArray. At runtime, the storage_id can be
+   * used to lookup the parameter.
+   */
+  std::unordered_map<std::string, runtime::NDArray> params_;
+  /*! \brief mapping between expression and parameters */
+  Map<Expr, String> params_by_expr_;
+  /*! \brief mapping between parameter names ("p0", "p1", etc..) and storage identifiers */
+  std::unordered_map<std::string, int64_t> param_storage_ids_;
+
+  /*! \brief plan memory of device result */
+  Map<Expr, Array<IntegerArray>> storage_device_map_;
+  std::unordered_map<int, te::Var> sids_table_;
+  /*! \brief lowered funcs */
+  std::unordered_map<std::string, IRModule> lowered_funcs_;
+  /*! \brief compile engine */
+  CompileEngine compile_engine_;
+  /*! \brief the set of statements that make the program */
+  std::vector<tir::Stmt> stmts_;
+  /*! \brief The list of return sids (note that the function might return more than one output) */
+  IntegerArray return_sid_;
+
+ public:
+  AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
+      : mod_(mod), return_sid_() {
+    compile_engine_ = CompileEngine::Global();
+    targets_ = targets;
+    target_host_ = target_host;
+  }
+
+  LoweredOutput Codegen(relay::Function func) {
+    // Get the module, storage map and token sizes
+    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
+    storage_device_map_ = (*pf)(func);
+
+    int input_index = 0;
+    for (auto input : func->params) {
+      input_vars_.push_back(input);
+      main_signature_.push_back(tir::Var(MakeString("input_", input_index), DataType::Handle()));
+    }
+
+    // Define the storage allocator ids
+    for (auto kv : storage_device_map_) {
+      for (const auto& sid : kv.second[0]) {
+        te::Var sid_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8))));
+        sids_table_[sid] = sid_var;
+      }
+    }
+
+    // Find the return sid
+    return_sid_ = AotReturnSidVisitor(storage_device_map_).FindReturnSid(func);
+    for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) {
+      main_signature_.push_back(tir::Var(MakeString("output_", output_index), DataType::Handle()));
+    }
+
+    VisitExpr(func->body);
+
+    auto prim_func = CreateMainFunc(func->params.size());
+    LoweredOutput ret;
+
+    ret.params = std::unordered_map<std::string, std::pair<int, runtime::NDArray>>();
+    for (auto param : params_) {
+      ret.params.emplace(std::make_pair(
+          param.first,
+          std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));
+    }
+
+    for (auto& kv : lowered_funcs_) {
+      if (ret.lowered_funcs.count(kv.first) == 0) {
+        ret.lowered_funcs.Set(kv.first, IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+      auto& mod = ret.lowered_funcs[kv.first];
+      mod->Update(kv.second);
+      ret.lowered_funcs.Set(kv.first, mod);
+    }
+    ret.external_mods = compile_engine_->LowerExternalFunctions();
+
+    auto target_host_str = target_host_->str();
+    if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
+      ret.lowered_funcs[target_host_str]->Add(
+          GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
+    } else {
+      Map<GlobalVar, BaseFunc> symbol_map;
+      symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
+      ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map));
+    }
+
+    ret.metadata =
+        runtime::Metadata(input_vars_.size(), return_sid_.size(), runtime::kTvmExecutorAot);
+    return ret;
+  }
+};
+
+class AOTExecutorCodegenModule : public runtime::ModuleNode {
+ public:
+  AOTExecutorCodegenModule() {}
+  virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
+    if (name == "init") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        ICHECK_EQ(args.num_args, 2) << "The expected arguments are: "
+                                    << "runtime::Module mod and Map<int, Target> targets";
+        void* mod = args[0];
+        Map<Integer, tvm::Target> targets = args[1];
+        init(mod, targets);
+      });
+    } else if (name == "codegen") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        Function func = args[0];
+        this->output_ = codegen(func);
+      });
+    } else if (name == "list_params_name") {
+      return PackedFunc(
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = list_params_name(); });
+    } else if (name == "get_param_by_name") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        String key = args[0];
+        *rv = get_param_by_name(key);
+      });
+    } else if (name == "get_param_id") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args,
TVMRetValue* rv) { + String key = args[0]; + *rv = get_param_id(key); + }); + } else if (name == "get_irmodule") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); }); + } else if (name == "get_external_modules") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_external_modules(); }); + } else if (name == "get_metadata") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = output_.metadata; }); + } else { + return PackedFunc([](TVMArgs args, TVMRetValue* rv) {}); + } + } + + const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } + + private: + void init(void* mod, Map tmp) { + TargetsMap targets; + Target target_host; + for (const auto& it : tmp) { + auto dev_type = it.first.as(); + if (!target_host.defined() && it.second->kind->device_type == kDLCPU) { + target_host = it.second; + } + ICHECK(dev_type); + targets[dev_type->value] = it.second; + } + codegen_ = std::make_shared(reinterpret_cast(mod), + targets, target_host); + } + + LoweredOutput codegen(Function func) { return this->codegen_->Codegen(func); } + + Array list_params_name() { + Array ret; + for (const auto& kv : this->output_.params) { + ret.push_back(kv.first); + } + return ret; + } + + runtime::NDArray get_param_by_name(String key) { + auto it = this->output_.params.find(key); + CHECK(it != this->output_.params.end()) << "no such parameter " << key; + return (*it).second.second; + } + + Array get_external_modules() { return output_.external_mods; } + + int get_param_id(String key) { + auto it = this->output_.params.find(key); + CHECK(it != this->output_.params.end()) << "no such parameter " << key; + return (*it).second.first; + } + + Map get_irmodule() { return this->output_.lowered_funcs; } + + std::shared_ptr codegen_; + LoweredOutput output_; +}; + +runtime::Module CreateAOTExecutorCodegenMod() { + auto ptr = make_object(); + return runtime::Module(ptr); +} + +TVM_REGISTER_GLOBAL("relay.build_module._AOTExecutorCodegen") + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateAOTExecutorCodegenMod(); }); + +} // namespace backend +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index f93ac394230c..71f19a1c21bc 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -48,7 +48,6 @@ using namespace tvm::relay::transform; /*! * \brief Output of building module - * */ struct BuildOutput { std::string graph_json; @@ -56,31 +55,12 @@ struct BuildOutput { std::unordered_map params; }; -/*! 
- * \brief GraphCodegen module wrapper - * - */ -struct GraphCodegen { - public: - GraphCodegen() { - auto pf = GetPackedFunc("relay.build_module._GraphExecutorCodegen"); - mod = (*pf)(); - } - ~GraphCodegen() {} - +struct ExecutorCodegen { void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); } void Codegen(const Function& func) { CallFunc("codegen", func); } - std::string GetJSON() { return CallFunc("get_graph_json", nullptr); } - - Array GetExternalModules() { - return CallFunc>("get_external_modules", nullptr); - } - - Map GetIRModule() { - return CallFunc>("get_irmodule", nullptr); - } + virtual void UpdateOutput(BuildOutput* ret) = 0; std::unordered_map GetParams() { std::unordered_map ret; @@ -104,6 +84,17 @@ struct GraphCodegen { return ret; } + Array GetExternalModules() { + return CallFunc>("get_external_modules", nullptr); + } + + Map GetIRModule() { + return CallFunc>("get_irmodule", nullptr); + } + + runtime::Metadata GetMetadata() { return CallFunc("get_metadata"); } + virtual ~ExecutorCodegen() {} + protected: tvm::runtime::Module mod; template @@ -119,6 +110,48 @@ struct GraphCodegen { } }; +struct AOTCodegen : ExecutorCodegen { + AOTCodegen() { + auto pf = GetPackedFunc("relay.build_module._AOTExecutorCodegen"); + mod = (*pf)(); + } + + void UpdateOutput(BuildOutput* ret) override { ret->graph_json = ""; } + + ~AOTCodegen() {} +}; + +/*! + * \brief GraphCodegen module wrapper + * + */ +struct GraphCodegen : ExecutorCodegen { + GraphCodegen() { + auto pf = GetPackedFunc("relay.build_module._GraphExecutorCodegen"); + mod = (*pf)(); + } + void UpdateOutput(BuildOutput* ret) override { ret->graph_json = GetGraphJSON(); } + + std::string GetGraphJSON() { return CallFunc("get_graph_json", nullptr); } + + ~GraphCodegen() {} +}; + +/*! + * \brief Executor codegen factory function + */ +std::unique_ptr MakeExecutorCodegen(String executor_str) { + std::unique_ptr ret; + if (executor_str == runtime::kTvmExecutorGraph) { + ret = std::make_unique(); + } else if (executor_str == runtime::kTvmExecutorAot) { + ret = std::make_unique(); + } else { + CHECK(false) << "Executor " << executor_str << " not supported"; + } + return ret; +} + /*! 
* \brief Relay build module * @@ -140,8 +173,8 @@ class RelayBuildModule : public runtime::ModuleNode { [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); }); } else if (name == "build") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.num_args, 3); - this->Build(args[0], args[1], args[2]); + ICHECK_EQ(args.num_args, 4); + this->Build(args[0], args[1], args[2], args[3]); }); } else if (name == "list_params") { return PackedFunc( @@ -158,11 +191,11 @@ class RelayBuildModule : public runtime::ModuleNode { }); } else if (name == "get_irmodule") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->graph_codegen_->GetIRModule(); + *rv = this->executor_codegen_->GetIRModule(); }); } else if (name == "get_external_modules") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->graph_codegen_->GetExternalModules(); + *rv = this->executor_codegen_->GetExternalModules(); }); } else if (name == "optimize") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -237,10 +270,12 @@ class RelayBuildModule : public runtime::ModuleNode { * \param target Target device * \param target_host Host target device */ - void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { + void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host, + const String executor) { // Create protected variable targets_ from ground up targets_ = targets; target_host_ = target_host; + executor_ = executor; CheckAndUpdateHostConsistency(&targets_, &target_host_); BuildRelay(mod, params_); // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096. @@ -477,23 +512,23 @@ class RelayBuildModule : public runtime::ModuleNode { // Relay IRModule -> IRModule optimizations. relay_module = Optimize(relay_module, targets_, params); + // Get the updated function. auto func = Downcast(relay_module->Lookup("main")); // Generate code for the updated function. - graph_codegen_ = std::unique_ptr(new GraphCodegen()); - graph_codegen_->Init(nullptr, targets_); - graph_codegen_->Codegen(func); + executor_codegen_ = MakeExecutorCodegen(executor_); + executor_codegen_->Init(nullptr, targets_); + executor_codegen_->Codegen(func); + executor_codegen_->UpdateOutput(&ret_); + ret_.params = executor_codegen_->GetParams(); - ret_.graph_json = graph_codegen_->GetJSON(); - ret_.params = graph_codegen_->GetParams(); - - auto lowered_funcs = graph_codegen_->GetIRModule(); + auto lowered_funcs = executor_codegen_->GetIRModule(); // Generate a placeholder function that attaches linked params as its arguments. 
if (target_host->GetAttr("link-params").value_or(Bool(false))) { CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen."; - auto param_ids = graph_codegen_->GetParamIds(); + auto param_ids = executor_codegen_->GetParamIds(); auto link_params = Map(); for (auto param : ret_.params) { link_params.Set(param.first, tir::LinkedParam(param_ids[param.first], param.second)); @@ -527,8 +562,9 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.mod = tvm::build(lowered_funcs, target_host_); } - auto ext_mods = graph_codegen_->GetExternalModules(); - ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, GetTargetHost()); + auto ext_mods = executor_codegen_->GetExternalModules(); + ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, GetTargetHost(), + executor_codegen_->GetMetadata()); } private: @@ -546,7 +582,7 @@ class RelayBuildModule : public runtime::ModuleNode { } protected: - std::unique_ptr graph_codegen_; + std::unique_ptr executor_codegen_; /*! \brief target device */ TargetsMap targets_; /*! \brief target host device */ @@ -555,6 +591,12 @@ class RelayBuildModule : public runtime::ModuleNode { std::unordered_map params_; /*! \brief building output */ BuildOutput ret_; + /*! + * \brief Executor used to execute the model: + * - graph: use the json graph executor + * - aot: use the aot executor + */ + String executor_; }; runtime::Module RelayBuildCreate() { diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 8b8ff287eed4..2e36dc6a76c7 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -52,14 +52,6 @@ using GraphInputObjectPtr = std::shared_ptr; using GraphOpObjectPtr = std::shared_ptr; using TargetsMap = std::unordered_map; -/*! \brief Lowered outputs */ -struct LoweredOutput { - std::string graph_json; - Map lowered_funcs; - Array external_mods; - std::unordered_map> params; -}; - /*! 
\brief Node types */ enum GraphNodeType { kGraphNop, @@ -251,7 +243,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator storage_info; for (auto& v : storage_device_info[0]) { @@ -648,6 +640,9 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.external_mods; }); + } else if (name == "get_metadata") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.metadata; }); } else { return PackedFunc([](TVMArgs args, TVMRetValue* rv) {}); } diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index cf843236da61..351469d6e1ca 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -209,6 +209,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { for (const auto& kv : token_map_) { std::vector storage_ids; std::vector device_types; + std::vector sid_sizes_byte; for (StorageToken* tok : kv.second) { if (tok->device_type) { num_annotated_nodes++; @@ -216,8 +217,10 @@ class StorageAllocator : public StorageAllocaBaseVisitor { num_nodes++; storage_ids.push_back(tok->storage_id); device_types.push_back(tok->device_type); + sid_sizes_byte.push_back(GetMemorySize(tok)); } - smap.Set(GetRef(kv.first), Array({storage_ids, device_types})); + smap.Set(GetRef(kv.first), + Array({storage_ids, device_types, sid_sizes_byte})); } // Either all or none of the nodes should be annotated. if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 6908ca85f582..c804768c99af 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -37,12 +37,27 @@ #include #include #include +#include #include +#include "../../runtime/meta_data.h" + namespace tvm { namespace relay { namespace backend { +/*! + * \brief Executor generator artifacts. Those artifacts are subsequently + * used by the relay build process. + */ +struct LoweredOutput { + std::string graph_json; + Map lowered_funcs; + Array external_mods; + std::unordered_map> params; + runtime::Metadata metadata; +}; + /*! * \brief A helper to expand the params by adding the ones used in a given expression. */ diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e562bf242ac9..afc01aab53c2 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1189,7 +1189,7 @@ void VMCompiler::Codegen() { // to make sure a DSO module will be also available. 
    lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
   }
-  lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_);
+  lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_, runtime::Metadata());
   exec_->SetLib(lib);
   CompileEngine::Global()->Clear();
 }
diff --git a/src/runtime/crt/Makefile b/src/runtime/crt/Makefile
index 8d3acab1858b..38c53d273a6e 100644
--- a/src/runtime/crt/Makefile
+++ b/src/runtime/crt/Makefile
@@ -68,6 +68,7 @@ endef
 LIBS = \
 	src/runtime/crt/common \
 	src/runtime/crt/graph_executor \
+	src/runtime/crt/aot_executor \
 	src/runtime/crt/graph_executor_module \
 	src/runtime/crt/memory \
 	src/runtime/crt/utvm_rpc_common \
diff --git a/src/runtime/crt/aot_executor/aot_executor.c b/src/runtime/crt/aot_executor/aot_executor.c
new file mode 100644
index 000000000000..3880493d1780
--- /dev/null
+++ b/src/runtime/crt/aot_executor/aot_executor.c
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Main entry point for running a compiled AOT model
+ * \param model Model descriptor structure to reference for runtime information
+ * \param inputs Pointer to input pointer(s)
+ * \param outputs Pointer to output pointer(s)
+ * \return tvm_crt_error_t containing success or errors from the model run
+ */
+#include
+#include
+
+tvm_crt_error_t tvm_runtime_run(const tvm_model_t* model, void** inputs, void** outputs) {
+  static DLDevice fake_device = {kDLCPU, 0};
+  static int64_t fake_dims = 0;
+  static int64_t fake_shape = {0};
+
+  DLTensor tensors[model->num_input_tensors + model->num_output_tensors];     // NOLINT
+  TVMValue tvm_values[model->num_input_tensors + model->num_output_tensors];  // NOLINT
+  int32_t tvm_typeids[model->num_input_tensors + model->num_output_tensors];  // NOLINT
+
+  for (int i = 0; i < model->num_input_tensors; i++) {
+    tensors[i] = (DLTensor){
+        .device = fake_device,
+        .data = inputs[i],
+        .shape = &fake_shape,
+        .ndim = fake_dims,
+        .byte_offset = 0,
+        .strides = NULL,
+    };
+    tvm_values[i].v_handle = &tensors[i];
+  }
+
+  for (int i = 0; i < model->num_output_tensors; i++) {
+    tensors[model->num_input_tensors + i] = (DLTensor){
+        .device = fake_device,
+        .data = outputs[i],
+        .shape = &fake_shape,
+        .ndim = fake_dims,
+        .byte_offset = 0,
+        .strides = NULL,
+    };
+    tvm_values[model->num_input_tensors + i].v_handle = &tensors[model->num_input_tensors + i];
+  }
+
+  return model->run_func(tvm_values, tvm_typeids, 0, NULL, 0, NULL);
+}
diff --git a/src/runtime/crt/common/crt_backend_api.c b/src/runtime/crt/common/crt_backend_api.c
index 9a12bc28240a..56bbbedc1d64 100644
--- a/src/runtime/crt/common/crt_backend_api.c
+++ b/src/runtime/crt/common/crt_backend_api.c
@@ -27,7 +27,7 @@
 #include
 #include
 #include
-#include
+#include #include #include "crt_config.h" diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index f73449829bd6..f34bbd4fec95 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -31,8 +31,6 @@ #include #include #include -#include -#include #include // Handle internal errors diff --git a/src/runtime/crt/common/ndarray.c b/src/runtime/crt/common/ndarray.c index fb8fc8022f43..c97f7658938f 100644 --- a/src/runtime/crt/common/ndarray.c +++ b/src/runtime/crt/common/ndarray.c @@ -25,7 +25,7 @@ */ #include -#include +#include #include #include "crt_config.h" diff --git a/src/runtime/crt/crt_config-template.h b/src/runtime/crt/crt_config-template.h index 67e0608ab696..907559421e5d 100644 --- a/src/runtime/crt/crt_config-template.h +++ b/src/runtime/crt/crt_config-template.h @@ -51,4 +51,7 @@ /*! \brief DLDataType for the return value from strlen */ #define TVM_CRT_STRLEN_DLTYPE 10 +/*! \brief Enable checks to enforce the stack allocator with a FIFO ordering. Off by default */ +// #define TVM_CRT_STACK_ALLOCATOR_ENABLE_FIFO_CHECK + #endif // TVM_RUNTIME_CRT_CRT_CONFIG_TEMPLATE_H_ diff --git a/src/runtime/crt/graph_executor/graph_executor.c b/src/runtime/crt/graph_executor/graph_executor.c index 2fe9e73aeddc..bf64096441be 100644 --- a/src/runtime/crt/graph_executor/graph_executor.c +++ b/src/runtime/crt/graph_executor/graph_executor.c @@ -27,9 +27,9 @@ #include #include #include -#include #include #include +#include #include "crt_config.h" diff --git a/src/runtime/crt/graph_executor/load_json.c b/src/runtime/crt/graph_executor/load_json.c index dd2faecdc538..f1c1f6768168 100644 --- a/src/runtime/crt/graph_executor/load_json.c +++ b/src/runtime/crt/graph_executor/load_json.c @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include // the node entry structure in serialized format diff --git a/src/runtime/crt/host/crt_config.h b/src/runtime/crt/host/crt_config.h index b81a74eb4ae6..b0a68c939070 100644 --- a/src/runtime/crt/host/crt_config.h +++ b/src/runtime/crt/host/crt_config.h @@ -51,6 +51,9 @@ /*! \brief Maximum length of a PackedFunc function name. */ #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 +/*! \brief Enable checks to enforce the stack allocator with a FIFO ordering. 
*/
+#define TVM_CRT_STACK_ALLOCATOR_ENABLE_FIFO_CHECK
+
 // #define TVM_CRT_FRAMER_ENABLE_LOGS

 #endif  // TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_
diff --git a/src/runtime/crt/host/main.cc b/src/runtime/crt/host/main.cc
index e2e4672cbc9d..0b0c81169756 100644
--- a/src/runtime/crt/host/main.cc
+++ b/src/runtime/crt/host/main.cc
@@ -25,7 +25,7 @@
 #include
 #include
 #include
-#include
+#include
 #include
 #include

@@ -123,7 +123,8 @@ int testonly_reset_server(TVMValue* args, int* type_codes, int num_args, TVMValu

 int main(int argc, char** argv) {
   g_argv = argv;
-  int status = MemoryManagerCreate(&memory_manager, memory, sizeof(memory), 8 /* page_size_log2 */);
+  int status =
+      PageMemoryManagerCreate(&memory_manager, memory, sizeof(memory), 8 /* page_size_log2 */);
   if (status != 0) {
     fprintf(stderr, "error initiailizing memory manager\n");
     return 2;
diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/aot_executor/aot_executor.h b/src/runtime/crt/include/tvm/runtime/crt/internal/aot_executor/aot_executor.h
new file mode 100644
index 000000000000..e49ca9933116
--- /dev/null
+++ b/src/runtime/crt/include/tvm/runtime/crt/internal/aot_executor/aot_executor.h
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief TVM Executor for the Ahead-of-Time Runtime
+ *
+ * AOT models are described by the TVM model descriptor format
+ * which can be passed to tvm_runtime_run. These descriptors will be
+ * generated by the AOT compilation process. This can optionally be
+ * augmented with platform specific context to be passed to the TVM
+ * operators.
+ *
+ * Example:
+ *   extern tvm_model_t my_network;
+ *   int main() {
+ *     void* data = get_data();
+ *     void* output[4] = {0, 0, 0, 0};
+ *     void* inputs[] = {data};
+ *     void* outputs[] = {output};
+ *     tvm_runtime_run(
+ *       &my_network,
+ *       inputs,
+ *       outputs
+ *     );
+ *     return 0;
+ *   }
+ */
+
+#ifndef TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_AOT_EXECUTOR_AOT_EXECUTOR_H_
+#define TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_AOT_EXECUTOR_AOT_EXECUTOR_H_
+
+#include
+#include
+#include
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*!
+ * \brief TVM Model descriptor to describe the
+ * model to the runtime.
+ */
+typedef struct {
+  uint32_t num_input_tensors;     /** Number of expected input tensors */
+  uint32_t num_output_tensors;    /** Number of expected output tensors */
+  TVMBackendPackedCFunc run_func; /** Generated model function, called through tvm_runtime_run */
+} tvm_model_t;
+
+/*!
+ * \brief Main entry point to execute the AOT runner function + * \param model Model descriptor structure to reference for runtime information + * \param inputs Pointer to input pointer(s) + * \param outputs Pointer to output pointer(s) + * \return tvm_status_t containing success or errors from the model run + */ +tvm_crt_error_t tvm_runtime_run(const tvm_model_t* model, void** inputs, void** outputs); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_AOT_EXECUTOR_AOT_EXECUTOR_H_ diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/memory/memory.h b/src/runtime/crt/include/tvm/runtime/crt/internal/memory/page_allocator.h similarity index 94% rename from src/runtime/crt/include/tvm/runtime/crt/internal/memory/memory.h rename to src/runtime/crt/include/tvm/runtime/crt/internal/memory/page_allocator.h index aae045a0f24d..7d40c03f2673 100644 --- a/src/runtime/crt/include/tvm/runtime/crt/internal/memory/memory.h +++ b/src/runtime/crt/include/tvm/runtime/crt/internal/memory/page_allocator.h @@ -18,17 +18,17 @@ */ /*! - * \file runtime/crt/include/tvm/runtime/crt/internal/memory/memory.h + * \file runtime/crt/include/tvm/runtime/crt/internal/memory/page_allocator.h * \brief Defines data types and functions used in the internal memory manager. * Exposed for testing. */ -#ifndef TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_MEMORY_MEMORY_H_ -#define TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_MEMORY_MEMORY_H_ +#ifndef TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_MEMORY_PAGE_ALLOCATOR_H_ +#define TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_MEMORY_PAGE_ALLOCATOR_H_ #include #include -#include +#include #include "crt_config.h" @@ -109,4 +109,4 @@ typedef struct MemoryManager { } // extern "C" #endif -#endif // TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_MEMORY_MEMORY_H_ +#endif // TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_MEMORY_PAGE_ALLOCATOR_H_ diff --git a/src/runtime/crt/memory/memory.c b/src/runtime/crt/memory/page_allocator.c similarity index 92% rename from src/runtime/crt/memory/memory.c rename to src/runtime/crt/memory/page_allocator.c index ed18544c2181..c016fe2acbef 100644 --- a/src/runtime/crt/memory/memory.c +++ b/src/runtime/crt/memory/page_allocator.c @@ -33,9 +33,8 @@ #include #include #include -#include +#include #include -#include #include // construct a new page @@ -123,8 +122,8 @@ void MultiMap_Insert(struct MultiMap* map, uint32_t npage, Page* p) { * \param size The size of memory * \return The virtual address */ -tvm_crt_error_t MemoryManager_Allocate(MemoryManagerInterface* interface, size_t num_bytes, - DLDevice dev, void** out_ptr) { +tvm_crt_error_t PageMemoryManager_Allocate(MemoryManagerInterface* interface, size_t num_bytes, + DLDevice dev, void** out_ptr) { MemoryManager* mgr = (MemoryManager*)interface; *out_ptr = 0; @@ -170,8 +169,8 @@ tvm_crt_error_t MemoryManager_Allocate(MemoryManagerInterface* interface, size_t * \param num_bytes The size of memory now required. * \return kTvmErrorNoError on success. 
*/ -tvm_crt_error_t MemoryManager_Realloc(MemoryManagerInterface* interface, void** ptr, - tvm_index_t num_bytes) { +tvm_crt_error_t PageMemoryManager_Realloc(MemoryManagerInterface* interface, void** ptr, + tvm_index_t num_bytes) { MemoryManager* mgr = (MemoryManager*)interface; uint8_t* data = *((uint8_t**)ptr); // NOLINT(*) @@ -259,7 +258,7 @@ tvm_crt_error_t MemoryManager_Realloc(MemoryManagerInterface* interface, void** * \param dev Execution device passed to TVMPlatformMemoryAllocate. Fixed to {kDLCPU, 0}. * \return kTvmErrorNoError if successful; a descriptive error code otherwise. */ -tvm_crt_error_t MemoryManager_Free(MemoryManagerInterface* interface, void* ptr, DLDevice dev) { +tvm_crt_error_t PageMemoryManager_Free(MemoryManagerInterface* interface, void* ptr, DLDevice dev) { MemoryManager* mgr = (MemoryManager*)interface; TLB* pmap = &(mgr->pmap); @@ -278,8 +277,9 @@ tvm_crt_error_t MemoryManager_Free(MemoryManagerInterface* interface, void* ptr, return kTvmErrorNoError; } -tvm_crt_error_t MemoryManagerCreate(MemoryManagerInterface** interface, uint8_t* memory_pool, - size_t memory_pool_size_bytes, size_t page_size_bytes_log2) { +tvm_crt_error_t PageMemoryManagerCreate(MemoryManagerInterface** interface, uint8_t* memory_pool, + size_t memory_pool_size_bytes, + size_t page_size_bytes_log2) { memset(memory_pool, 0, sizeof(memory_pool_size_bytes)); // Allocate enough space for MAX_PAGES. @@ -292,14 +292,14 @@ tvm_crt_error_t MemoryManagerCreate(MemoryManagerInterface** interface, uint8_t* MemoryManager* manager = (MemoryManager*)metadata_cursor; *interface = &manager->interface; /* handle MemoryManager member functions */ - manager->interface.Allocate = MemoryManager_Allocate; + manager->interface.Allocate = PageMemoryManager_Allocate; // manager->Realloc = MemoryManager_Reallocate; - manager->interface.Free = MemoryManager_Free; + manager->interface.Free = PageMemoryManager_Free; metadata_cursor += sizeof(MemoryManager); - manager->interface.Allocate = MemoryManager_Allocate; - manager->interface.Free = MemoryManager_Free; + manager->interface.Allocate = PageMemoryManager_Allocate; + manager->interface.Free = PageMemoryManager_Free; manager->ptable.memory_pool = memory_pool; /* handle PageTable member functions */ diff --git a/src/runtime/crt/memory/stack_allocator.c b/src/runtime/crt/memory/stack_allocator.c new file mode 100644 index 000000000000..6722816ec538 --- /dev/null +++ b/src/runtime/crt/memory/stack_allocator.c @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +// LINT_C_FILE +#include +#ifdef TVM_CRT_STACK_ALLOCATOR_ENABLE_FIFO_CHECK +#include +#endif + +tvm_crt_error_t StackMemoryManager_Allocate(tvm_workspace_t* tvm_runtime_workspace, int32_t nbytes, + void** current_alloc) { + // reserve bytes at the end of the allocation such that + // next_alloc % TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES == 0. + uint32_t offset_bytes = + (TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES - nbytes) & (TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES - 1); + uint8_t* workspace_end = tvm_runtime_workspace->workspace + tvm_runtime_workspace->workspace_size; + if (tvm_runtime_workspace->next_alloc + nbytes + offset_bytes > workspace_end) { + return kTvmErrorPlatformNoMemory; + } + (*current_alloc) = tvm_runtime_workspace->next_alloc; + uint8_t* next_alloc = tvm_runtime_workspace->next_alloc + nbytes + offset_bytes; +#ifdef TVM_CRT_STACK_ALLOCATOR_ENABLE_FIFO_CHECK + if (next_alloc + STACK_ALLOCATOR_TAG_SIZE_BYTES > workspace_end) { + return kTvmErrorPlatformNoMemory; + } + const uint32_t total_size = (nbytes + offset_bytes + STACK_ALLOCATOR_TAG_SIZE_BYTES); + *((uint32_t*)next_alloc) = total_size ^ STACK_ALLOCATOR_TAG; + next_alloc += STACK_ALLOCATOR_TAG_SIZE_BYTES; +#endif + + tvm_runtime_workspace->next_alloc = next_alloc; + return kTvmErrorNoError; +} + +tvm_crt_error_t StackMemoryManager_Free(tvm_workspace_t* tvm_runtime_workspace, void* ptr) { +#ifdef TVM_CRT_STACK_ALLOCATOR_ENABLE_FIFO_CHECK + uint32_t tag = *(((uint32_t*)tvm_runtime_workspace->next_alloc) - 1); + uint32_t actual_size = (tvm_runtime_workspace->next_alloc - (uint8_t*)ptr); + uint32_t expected_size = tag ^ STACK_ALLOCATOR_TAG; + CHECK_EQ(expected_size, actual_size, "Deallocation not in FIFO ordering"); +#endif + tvm_runtime_workspace->next_alloc = ptr; + return kTvmErrorNoError; +} + +tvm_crt_error_t StackMemoryManager_Init(tvm_workspace_t* tvm_runtime_workspace, + uint8_t* g_aot_memory, size_t workspace_size) { + tvm_runtime_workspace->next_alloc = g_aot_memory; + tvm_runtime_workspace->workspace = g_aot_memory; + tvm_runtime_workspace->workspace_size = workspace_size; + return kTvmErrorNoError; +} diff --git a/src/runtime/crt/utvm_rpc_server/rpc_server.cc b/src/runtime/crt/utvm_rpc_server/rpc_server.cc index 8b7c0eb01840..1736f98dad12 100644 --- a/src/runtime/crt/utvm_rpc_server/rpc_server.cc +++ b/src/runtime/crt/utvm_rpc_server/rpc_server.cc @@ -35,8 +35,8 @@ #include #include #include -#include #include +#include #include #include #include diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 03dba399fcb4..495b3f22e6ad 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -26,12 +26,14 @@ #include #include +#include #include #include #include #include #include +#include #include #include "runtime_base.h" @@ -39,6 +41,40 @@ namespace tvm { namespace runtime { +/*! + * \brief Structure that can be optionally used by the executor codegen + */ +class MetadataNode : public Object { + public: + /*! \brief number of inputs of the main function */ + int num_inputs = 1; + /*! \brief number of outputs of the main function */ + int num_outputs = 1; + /*! \brief the executor to be used to run the model */ + String executor = kTvmExecutorGraph; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "MetadataObj"; + TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, Object); +}; + +/*! + * \brief Managed reference to MetadataNode. 
+ */ +class Metadata : public ObjectRef { + public: + TVM_DLL Metadata(int num_inputs, int num_outputs, String executor) { + auto n = make_object(); + n->num_inputs = num_inputs; + n->num_outputs = num_outputs; + n->executor = executor; + data_ = std::move(n); + } + + TVM_DEFINE_OBJECT_REF_METHODS(Metadata, ObjectRef, MetadataNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(MetadataNode); +}; + /*! * \brief Create a metadata module object. * diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 8184e9189c4b..db4051e00fd2 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -21,7 +21,6 @@ * \file metadata_module.cc * \brief Defines functions that build MetadataModules for C++ and C runtimes. */ - #include "metadata_module.h" #include @@ -46,7 +45,8 @@ namespace codegen { */ runtime::Module CreateMetadataModule( const std::unordered_map& params, - tvm::runtime::Module target_module, const Array& ext_modules, Target target) { + tvm::runtime::Module target_module, const Array& ext_modules, Target target, + runtime::Metadata metadata) { // Here we split modules into two groups: // 1. Those modules which can be exported to C-runtime. These are DSO-exportable // (i.e. llvm or c) modules which return nothing from get_const_vars(). @@ -114,7 +114,7 @@ runtime::Module CreateMetadataModule( if (target->kind->name == "c") { crt_exportable_modules.push_back(target_module); - target_module = CreateCSourceCrtMetadataModule(crt_exportable_modules, target); + target_module = CreateCSourceCrtMetadataModule(crt_exportable_modules, target, metadata); } else if (target->kind->name == "llvm") { #ifdef TVM_LLVM_VERSION crt_exportable_modules.push_back(target_module); diff --git a/src/target/metadata_module.h b/src/target/metadata_module.h index 83cb29dd5a46..add05ba52692 100644 --- a/src/target/metadata_module.h +++ b/src/target/metadata_module.h @@ -33,12 +33,15 @@ #include #include +#include "../runtime/meta_data.h" + namespace tvm { namespace codegen { runtime::Module CreateMetadataModule( const std::unordered_map& params, - tvm::runtime::Module target_module, const Array& ext_modules, Target target); + tvm::runtime::Module target_module, const Array& ext_modules, Target target, + runtime::Metadata metadata); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index b643459be4b6..1627b6003391 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -852,8 +852,11 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; const VarNode* buffer = op->buffer_var.as(); - std::string scope = alloc_storage_scope_.at(buffer); - PrintStorageScope(scope, stream); + auto it = alloc_storage_scope_.find(buffer); + if (it != alloc_storage_scope_.end()) { + std::string scope = alloc_storage_scope_.at(buffer); + PrintStorageScope(scope, stream); + } PrintType(op->dtype, stream); stream << ' ' << vid << '[' << constant_size << "];\n"; diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index af4bb48d1d73..0bfbade23f01 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -234,6 +234,54 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar this->stream << "}\n"; } +void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int 
num_args) { + this->PrintIndent(); + std::string ret_val = GetUniqueName("ret_val"); + std::string ret_type_code = GetUniqueName("ret_type_code"); + this->stream << "TVMValue " << ret_val << ";\n"; + this->PrintIndent(); + this->stream << "int " << ret_type_code << ";\n"; + this->PrintIndent(); + + this->stream << "if (" << packed_func_name << "( " + << "(TVMValue*) stack_value " + << ", " + << "(int*) stack_tcode" + << ", " << num_args << ", " + << "&" << ret_val << ", " + << "&" << ret_type_code << ", NULL) != 0){\n"; + + int func_call_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << "return -1;\n"; + this->EndScope(func_call_scope); + this->PrintIndent(); + this->stream << "}\n"; +} + +CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op) { + const StringImmNode* s = op->args[0].as(); + ICHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name"; + int64_t begin = op->args[3].as()->value; + int64_t end = op->args[4].as()->value; + int64_t num_args = end - begin; + ICHECK_GE(num_args, 0); + std::string func_name = s->value; + // NOTE: cannot rely on GetUnique for global decl_stream declarations + // because it is reset between AddFunction(). + std::string packed_func_name = func_name + "_packed"; + std::string unique_name; + auto it = declared_globals_.find(packed_func_name); + if (it != declared_globals_.end()) { + unique_name = it->second; + } else { + unique_name = GetUniqueName(packed_func_name); + declared_globals_[packed_func_name] = unique_name; + decl_stream << "static void* " << unique_name << " = NULL;\n"; + } + return {func_name, unique_name, num_args}; +} + void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->op.same_as(builtin::tvm_stack_alloca())) { std::string stack_name = GetUniqueName("stack"); @@ -258,27 +306,12 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT this->stream << "TVMValue " << stack_name << "[" << size << "];\n"; os << stack_name; } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { - const StringImmNode* s = op->args[0].as(); - ICHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name"; - int64_t begin = op->args[3].as()->value; - int64_t end = op->args[4].as()->value; - int64_t num_args = end - begin; - ICHECK_GE(num_args, 0); - std::string func_name = s->value; - // NOTE: cannot rely on GetUnique for global decl_stream declarations - // because it is reset between AddFunction(). 
- std::string packed_func_name = func_name + "_packed"; - std::string unique_name; - auto it = declared_globals_.find(packed_func_name); - if (it != declared_globals_.end()) { - unique_name = it->second; - } else { - unique_name = GetUniqueName(packed_func_name); - declared_globals_[packed_func_name] = unique_name; - decl_stream << "static void* " << unique_name << " = NULL;\n"; - } - this->PrintGetFuncFromBackend(func_name, unique_name); - this->PrintFuncCall(unique_name, num_args); + auto function_info = GetFunctionInfo(op); + this->PrintGetFuncFromBackend(function_info.func_name, function_info.func_name_packed); + this->PrintFuncCall(function_info.func_name_packed, function_info.num_args); + } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) { + auto function_info = GetFunctionInfo(op); + this->PrintFuncCallC(function_info.func_name, function_info.num_args); } else if (op->op.same_as(builtin::tvm_throw_last_error())) { this->PrintIndent(); this->stream << "return -1;\n"; @@ -336,6 +369,8 @@ runtime::Module BuildCHost(IRModule mod, Target target) { Map linked_params; bool found_linked_params = false; bool could_have_linked_params = target->GetAttr("link-params").value_or(Bool(false)); + PrimFunc aot_executor_fn; + for (auto kv : mod->functions) { if (could_have_linked_params && kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) { @@ -347,6 +382,16 @@ runtime::Module BuildCHost(IRModule mod, Target target) { found_linked_params = true; continue; } + // Make sure that the executor function is the last one to be code generated so that all the + // symbols are available to tvm_run_func + auto fun_name = std::string(kv.first->name_hint); + const bool is_aot_executor_fn = + (fun_name.rfind(::tvm::runtime::symbol::tvm_run_func_prefix, 0) == 0); + + if (is_aot_executor_fn) { + aot_executor_fn = Downcast(kv.second); + continue; + } ICHECK(kv.second->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; auto f = Downcast(kv.second); @@ -358,6 +403,10 @@ runtime::Module BuildCHost(IRModule mod, Target target) { cg.LinkParameters(linked_params); } + if (aot_executor_fn.defined()) { + cg.AddFunction(aot_executor_fn); + } + if (target->GetAttr("system-lib").value_or(Bool(false))) { ICHECK_EQ(target->GetAttr("runtime").value_or(""), "c") << "c target only supports generating C runtime SystemLibs"; diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index eace09f13a07..2ee31b8c7e0e 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -62,6 +62,15 @@ class CodeGenCHost final : public CodeGenC { Array GetFunctionNames() { return function_names_; } private: + /* \brief Internal structure to store information about function calls */ + struct FunctionInfo { + /* \brief function name */ + std::string func_name; + /* packed name of the function */ + std::string func_name_packed; + /* number of arguments required by the function */ + int64_t num_args; + }; std::string module_name_; /* \brief mapping global packed func to the unique name */ std::unordered_map declared_globals_; @@ -70,8 +79,10 @@ class CodeGenCHost final : public CodeGenC { /*! 
\brief whether to emit asserts in the resulting C code */ bool emit_asserts_; + FunctionInfo GetFunctionInfo(const CallNode* op); void PrintGetFuncFromBackend(const std::string& func_name, const std::string& packed_func_name); void PrintFuncCall(const std::string& packed_func_name, int num_args); + void PrintFuncCallC(const std::string& packed_func_name, int num_args); /*! * \brief Print ternary conditional operator implementing binary `op` diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 3baa44eb639f..ff0d079f5425 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -1,3 +1,4 @@ + /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -155,7 +156,7 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, */ runtime::Module CreateMetadataModule( const std::unordered_map& params, runtime::Module target_module, - const Array& ext_modules, Target target); + const Array& ext_modules, Target target, runtime::Metadata metadata); /*! * \brief Create a source module for viewing and limited saving for device. @@ -173,10 +174,11 @@ runtime::Module DeviceSourceModuleCreate( * \brief Wrap the submodules that are to be wrapped in a c-source metadata module for C runtime. * \param modules The modules to be wrapped. * \param target the target the modules are compiled for. + * \param metadata the metadata needed for code generation. * \return The wrapped module. */ -runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, - Target target); +runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, + runtime::Metadata metadata); } // namespace codegen } // namespace tvm diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 26f1850c0e47..661df9305036 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -130,8 +130,8 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { public: CSourceCrtMetadataModuleNode(const Array& func_names, const std::string& fmt, - Target target) - : fmt_(fmt), func_names_(func_names), target_(target) { + Target target, runtime::Metadata metadata) + : fmt_(fmt), func_names_(func_names), target_(target), metadata_(metadata) { CreateSource(); } const char* type_key() const { return "c"; } @@ -159,6 +159,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { std::string fmt_; Array func_names_; Target target_; + runtime::Metadata metadata_; void CreateFuncRegistry() { code_ << "#include \n"; @@ -191,17 +192,36 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { << "}\n"; } + void GenerateAOTDescriptor() { + code_ << "#include \"tvm/runtime/crt/internal/aot_executor/aot_executor.h\"\n"; + code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n"; + code_ << "#ifdef __cplusplus\n"; + code_ << "extern \"C\"\n"; + code_ << "#endif\n"; + code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix; + code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " + "out_type_code, void* resource_handle);\n"; + code_ << "const tvm_model_t network = {\n" + << " .run_func = &" << ::tvm::runtime::symbol::tvm_run_func_prefix << ",\n" + << " .num_input_tensors = " << metadata_->num_inputs << ",\n" + << " .num_output_tensors = " << 
metadata_->num_outputs << ", \n" + << "};\n"; + } + void CreateSource() { if (target_->GetAttr("system-lib").value_or(Bool(false)) && !func_names_.empty()) { CreateFuncRegistry(); GenerateCrtSystemLib(); } + if (metadata_.defined() && metadata_->executor == runtime::kTvmExecutorAot) { + GenerateAOTDescriptor(); + } code_ << ";"; } }; -runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, - Target target) { +runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, + runtime::Metadata metadata) { Array func_names; for (runtime::Module mod : modules) { auto pf_funcs = mod.GetFunction("get_func_names"); @@ -212,7 +232,7 @@ runtime::Module CreateCSourceCrtMetadataModule(const Array& mod } } } - auto n = make_object(func_names, "cc", target); + auto n = make_object(func_names, "cc", target, metadata); auto csrc_metadata_module = runtime::Module(n); for (const auto& mod : modules) { csrc_metadata_module.Import(mod); @@ -283,7 +303,8 @@ TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") TVM_REGISTER_GLOBAL("runtime.CreateCSourceCrtMetadataModule") .set_body_typed([](const Array& modules, Target target) { - return CreateCSourceCrtMetadataModule(modules, target); + // Note that we don't need metadata when we compile a single operator + return CreateCSourceCrtMetadataModule(modules, target, runtime::Metadata()); }); } // namespace codegen diff --git a/src/target/source/source_module.h b/src/target/source/source_module.h index 45858b9f4ef2..6226ba2f22b3 100644 --- a/src/target/source/source_module.h +++ b/src/target/source/source_module.h @@ -29,6 +29,8 @@ #include #include +#include "../../runtime/meta_data.h" + namespace tvm { namespace codegen { @@ -38,7 +40,7 @@ namespace codegen { * \param target TVM target. */ runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, - tvm::Target target); + tvm::Target target, runtime::Metadata metadata); } // namespace codegen } // namespace tvm diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 08842554257b..474b1b0d8ac4 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -227,6 +227,7 @@ TVM_REGISTER_TARGET_KIND("c", kDLCPU) .add_attr_option("runtime") .add_attr_option("mcpu") .add_attr_option("march") + .add_attr_option("executor") .set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("cuda", kDLGPU) @@ -308,8 +309,7 @@ TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev) // line break TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break .add_attr_option("system-lib"); -TVM_REGISTER_TARGET_KIND("composite", kDLCPU) - .add_attr_option>("devices"); +TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option>("devices"); /********** Registry **********/ diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 1117571c8b75..f3ab78f89bec 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -174,6 +174,9 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array) TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -184,6 +187,9 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context) TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked_lowered) + .set_attr("TCallEffectKind", 
Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 8d2857ef7a40..0e2e612e3ae8 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -179,7 +179,9 @@ class BuiltinLower : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_call_packed())) { - return MakeCallPacked(op); + return MakeCallPacked(op, /* use_string_lookup */ true); + } else if (op->op.same_as(builtin::tvm_call_cpacked())) { + return MakeCallPacked(op, /* use_string_lookup */ false); } else if (op->op.same_as(builtin::tvm_call_trace_packed())) { return MakeCallTracePacked(op); } else if (op->op.same_as(builtin::tvm_stack_make_shape())) { @@ -256,7 +258,7 @@ class BuiltinLower : public StmtExprMutator { return TVMStructGet(DataType::Handle(), scope.stack_array, idx, builtin::kArrAddr); } // call packed. - PrimExpr MakeCallPacked(const CallNode* op) { + PrimExpr MakeCallPacked(const CallNode* op, bool use_string_lookup) { auto& scope = alloca_scope_.back(); auto& prep_seq = prep_seq_stack_.back(); @@ -297,8 +299,10 @@ class BuiltinLower : public StmtExprMutator { Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode, ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1)}; - // call_packed_lowered needs to do the type casting properly - return Call(op->dtype, builtin::tvm_call_packed_lowered(), packed_args); + + auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered() + : builtin::tvm_call_cpacked_lowered(); + return Call(op->dtype, builtin_call, packed_args); } PrimExpr MakeCallTracePacked(const CallNode* op) { diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 344fd3d40ba8..314185240563 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -119,7 +120,7 @@ TEST(Relay, BuildModule) { targets.Set(0, llvm_tgt); auto relay_mod = tvm::IRModule::FromExpr(func); ICHECK(relay_mod.defined()) << "Module must be defined"; - build_f(relay_mod, targets, llvm_tgt); + build_f(relay_mod, targets, llvm_tgt, runtime::kTvmExecutorGraph); std::string json = json_f(); tvm::runtime::Module mod = mod_f(); // run diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/utvm_runtime_standalone_test.cc index 5c642a37d6bc..e674c3b74144 100644 --- a/tests/cpp/utvm_runtime_standalone_test.cc +++ b/tests/cpp/utvm_runtime_standalone_test.cc @@ -37,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -91,7 +92,7 @@ TEST(MicroStandaloneRuntime, BuildModule) { Target llvm_tgt = Target("llvm"); targets.Set(0, llvm_tgt); - build_f(func, targets, llvm_tgt); + build_f(func, targets, llvm_tgt, runtime::kTvmExecutorGraph); std::string json = json_f(); tvm::runtime::Module mod = mod_f(); std::string o_fname = std::tmpnam(nullptr); diff --git a/tests/crt/aot_executor_test.cc b/tests/crt/aot_executor_test.cc new file mode 100644 index 000000000000..ded6729d138b --- /dev/null +++ b/tests/crt/aot_executor_test.cc @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +int test_run_func(TVMValue* args, int* arg_type_ids, int num_args, TVMValue* out_ret_value, + int* out_ret_tcode, void* resource_handle) { + return kTvmErrorNoError; +} + +TEST(AOTRuntime, NoOp) { + const tvm_model_t test_model = { + .num_input_tensors = 0, + .num_output_tensors = 0, + .run_func = &test_run_func, + }; + + ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&test_model, NULL, NULL)); +} + +int32_t error_run_func(TVMValue* args, int* arg_type_ids, int32_t num_args, TVMValue* out_ret_value, + int* out_ret_tcode, void* resource_handle) { + return kTvmErrorPlatformNoMemory; +} + +TEST(AOTRuntime, Error) { + const tvm_model_t error_model = { + .num_input_tensors = 0, + .num_output_tensors = 0, + .run_func = &error_run_func, + }; + + ASSERT_EQ(kTvmErrorPlatformNoMemory, tvm_runtime_run(&error_model, NULL, NULL)); +} + +int32_t identity_run_func(TVMValue* args, int* arg_type_ids, int32_t num_args, + TVMValue* out_ret_value, int* out_ret_tcode, void* resource_handle) { + void* arg0 = (((TVMValue*)args)[0].v_handle); + void* arg1 = (((TVMValue*)args)[1].v_handle); + void* placeholder = (((DLTensor*)arg0)[0].data); + void* T_id = (((DLTensor*)arg1)[0].data); + ((uint32_t*)T_id)[(0)] = ((uint32_t*)placeholder)[(0)]; + return kTvmErrorNoError; +} + +TEST(AOTRuntime, Identity) { + const tvm_model_t identity_model = { + .num_input_tensors = 1, + .num_output_tensors = 1, + .run_func = &identity_run_func, + }; + + uint32_t inputs1[1] = {404}; + void* inputs[] = {inputs1}; + uint32_t outputs1[1]; + void* outputs[] = {outputs1}; + + ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&identity_model, inputs, outputs)); + ASSERT_EQ(outputs1[0], 404); +} + +int32_t add_run_func(TVMValue* args, int* arg_type_ids, int32_t num_args, TVMValue* out_ret_value, + int* out_ret_tcode, void* resource_handle) { + void* arg0 = (((TVMValue*)args)[0].v_handle); + void* arg1 = (((TVMValue*)args)[1].v_handle); + void* placeholder = (((DLTensor*)arg0)[0].data); + void* T_add = (((DLTensor*)arg1)[0].data); + ((uint32_t*)T_add)[(0)] = ((uint32_t*)placeholder)[(0)] + ((uint32_t*)placeholder)[(1)]; + return kTvmErrorNoError; + + return kTvmErrorNoError; +} + +TEST(AOTRuntime, Add) { + const tvm_model_t add_model = { + .num_input_tensors = 1, + .num_output_tensors = 1, + .run_func = &add_run_func, + }; + + uint32_t inputs1[2] = {404, 500}; + void* inputs[] = {inputs1}; + uint32_t outputs1[1]; + void* outputs[] = {outputs1}; + + ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&add_model, inputs, outputs)); + ASSERT_EQ(outputs1[0], 904); +} + +int32_t multiple_inputs_run_func(TVMValue* args, int* arg_type_ids, int32_t num_args, + TVMValue* out_ret_value, int* out_ret_tcode, + void* resource_handle) { + void* arg0 = (((TVMValue*)args)[0].v_handle); + void* arg1 = (((TVMValue*)args)[1].v_handle); + 
void* arg2 = (((TVMValue*)args)[2].v_handle); + void* placeholder = (((DLTensor*)arg0)[0].data); + void* placeholder1 = (((DLTensor*)arg1)[0].data); + void* T_add = (((DLTensor*)arg2)[0].data); + ((uint32_t*)T_add)[(0)] = ((uint32_t*)placeholder)[(0)] + ((uint32_t*)placeholder)[(1)] + + ((uint32_t*)placeholder1)[(0)] + ((uint32_t*)placeholder1)[(1)]; + return kTvmErrorNoError; +} + +TEST(AOTRuntime, MultipleInputs) { + const tvm_model_t multiple_inputs_model = { + .num_input_tensors = 2, + .num_output_tensors = 1, + .run_func = &multiple_inputs_run_func, + }; + + uint32_t inputs1[2] = {404, 500}; + uint32_t inputs2[2] = {200, 202}; + void* inputs[] = {inputs1, inputs2}; + + uint32_t outputs1[1]; + void* outputs[] = {outputs1}; + + ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&multiple_inputs_model, inputs, outputs)); + ASSERT_EQ(outputs1[0], 1306); +} + +int32_t multiple_outputs_run_func(TVMValue* args, int* arg_type_ids, int32_t num_args, + TVMValue* out_ret_value, int* out_ret_tcode, + void* resource_handle) { + void* arg0 = (((TVMValue*)args)[0].v_handle); + void* arg1 = (((TVMValue*)args)[1].v_handle); + void* arg2 = (((TVMValue*)args)[2].v_handle); + void* placeholder = (((DLTensor*)arg0)[0].data); + void* T_split1 = (((DLTensor*)arg1)[0].data); + void* T_split2 = (((DLTensor*)arg2)[0].data); + ((uint32_t*)T_split1)[(0)] = ((uint32_t*)placeholder)[(0)]; + ((uint32_t*)T_split2)[(0)] = ((uint32_t*)placeholder)[(1)]; + return kTvmErrorNoError; +} + +TEST(AOTRuntime, MultipleOutputs) { + const tvm_model_t multiple_outputs_model = { + .num_input_tensors = 1, + .num_output_tensors = 2, + .run_func = &multiple_outputs_run_func, + }; + + uint32_t inputs1[2] = {404, 500}; + void* inputs[] = {inputs1}; + + uint32_t outputs1[1]; + uint32_t outputs2[1]; + void* outputs[] = {outputs1, outputs2}; + + ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&multiple_outputs_model, inputs, outputs)); + ASSERT_EQ(outputs1[0], 404); + ASSERT_EQ(outputs2[0], 500); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/crt/aot_memory_test.cc b/tests/crt/aot_memory_test.cc new file mode 100644 index 000000000000..ecae2ef52f59 --- /dev/null +++ b/tests/crt/aot_memory_test.cc @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+#include <gtest/gtest.h>
+#include <tvm/runtime/crt/stack_allocator.h>
+
+#include "platform.cc"
+/*
+ * Tests allocations are properly aligned when allocated
+ */
+TEST(AOTMemory, Allocate) {
+  static uint8_t model_memory[96];
+  tvm_workspace_t tvm_runtime_workspace;
+
+  ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory, 96), kTvmErrorNoError);
+  void* block_one = NULL;
+  ASSERT_EQ(StackMemoryManager_Allocate(&tvm_runtime_workspace, 1, &block_one), kTvmErrorNoError);
+  ASSERT_EQ(block_one, &model_memory[0]);
+
+  void* block_two = NULL;
+  ASSERT_EQ(StackMemoryManager_Allocate(&tvm_runtime_workspace, 2, &block_two), kTvmErrorNoError);
+  ASSERT_EQ(block_two, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]);
+
+  void* two_blocks = NULL;
+  ASSERT_EQ(StackMemoryManager_Allocate(&tvm_runtime_workspace, 24, &two_blocks), kTvmErrorNoError);
+  ASSERT_EQ(two_blocks, &model_memory[32 + 2 * STACK_ALLOCATOR_TAG_SIZE_BYTES]);
+
+  void* block_three = NULL;
+  ASSERT_EQ(StackMemoryManager_Allocate(&tvm_runtime_workspace, 1, &block_three), kTvmErrorNoError);
+  ASSERT_EQ(block_three, &model_memory[64 + 3 * STACK_ALLOCATOR_TAG_SIZE_BYTES]);
+}
+
+/*
+ * Tests resetting the stack after dealloc
+ */
+TEST(AOTMemory, Free) {
+  static uint8_t model_memory[80];
+  tvm_workspace_t tvm_runtime_workspace;
+  ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory, 80), kTvmErrorNoError);
+
+  void* block_one = NULL;
+  ASSERT_EQ(StackMemoryManager_Allocate(&tvm_runtime_workspace, 1, &block_one), kTvmErrorNoError);
+  ASSERT_EQ(block_one, &model_memory[0]);
+
+  void* block_two = NULL;
+  ASSERT_EQ(StackMemoryManager_Allocate(&tvm_runtime_workspace, 1, &block_two), kTvmErrorNoError);
+  ASSERT_EQ(block_two, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]);
+  ASSERT_EQ(kTvmErrorNoError, StackMemoryManager_Free(&tvm_runtime_workspace, block_two));
+
+  void* two_blocks = NULL;
+  ASSERT_EQ(StackMemoryManager_Allocate(&tvm_runtime_workspace, 2, &two_blocks), kTvmErrorNoError);
+  ASSERT_EQ(two_blocks, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]);
+  ASSERT_EQ(kTvmErrorNoError, StackMemoryManager_Free(&tvm_runtime_workspace, two_blocks));
+
+  void* block_three = NULL;
+  ASSERT_EQ(StackMemoryManager_Allocate(&tvm_runtime_workspace, 1, &block_three), kTvmErrorNoError);
+  ASSERT_EQ(block_three, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]);
+}
+
+/*
+ * Tests that the block pointer stays NULL if we over-allocate
+ */
+TEST(AOTMemory, OverAllocate) {
+  static uint8_t model_memory[80];
+  tvm_workspace_t tvm_runtime_workspace;
+  ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory, 80), kTvmErrorNoError);
+
+  void* block_one = NULL;
+  ASSERT_EQ(StackMemoryManager_Allocate(&tvm_runtime_workspace, 1, &block_one), kTvmErrorNoError);
+  ASSERT_EQ(block_one, &model_memory[0]);
+
+  void* block_two = NULL;
+  ASSERT_EQ(StackMemoryManager_Allocate(&tvm_runtime_workspace, 1, &block_two), kTvmErrorNoError);
+  ASSERT_EQ(block_two, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]);
+
+  void* two_blocks = NULL;
+  ASSERT_EQ(StackMemoryManager_Allocate(&tvm_runtime_workspace, 64, &two_blocks),
+            kTvmErrorPlatformNoMemory);
+  ASSERT_EQ(two_blocks, (void*)NULL);
+}
+
+/*
+ * Test for out-of-order memory deallocation
+ */
+TEST(AOTMemory, FreeOutOfOrder) {
+  static uint8_t model_memory[80];
+  tvm_workspace_t tvm_runtime_workspace;
+  ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory, 80), kTvmErrorNoError);
+
+  void* block_one = NULL;
+  ASSERT_EQ(StackMemoryManager_Allocate(&tvm_runtime_workspace, 1, &block_one),
kTvmErrorNoError); + ASSERT_EQ(block_one, &model_memory[0]); + + void* block_two = NULL; + ASSERT_EQ(StackMemoryManager_Allocate(&tvm_runtime_workspace, 1, &block_two), kTvmErrorNoError); + ASSERT_EQ(block_two, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); + + ASSERT_EXIT(StackMemoryManager_Free(&tvm_runtime_workspace, block_one), + ::testing::ExitedWithCode(2), ""); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/crt/framing_test.cc b/tests/crt/framing_test.cc index 241e23d877cb..5ee226dc5ee7 100644 --- a/tests/crt/framing_test.cc +++ b/tests/crt/framing_test.cc @@ -18,7 +18,7 @@ */ #include -#include +#include #include #include diff --git a/tests/crt/memory_test.cc b/tests/crt/memory_test.cc index d876e5c96da9..b531383058e6 100644 --- a/tests/crt/memory_test.cc +++ b/tests/crt/memory_test.cc @@ -18,8 +18,8 @@ */ #include -#include -#include +#include +#include #include "crt_config.h" #include "platform.cc" @@ -37,7 +37,7 @@ class MemoryManagerTest : public ::testing::Test { void SetUp() override { memset(raw_memory_pool, 0, sizeof(raw_memory_pool)); memory_pool = (uint8_t*)(ROUND_UP(((uintptr_t)raw_memory_pool), (1 << kPageSizeBytesLog))); - MemoryManagerCreate(&interface, memory_pool, kMemoryPoolSizeBytes, kPageSizeBytesLog); + PageMemoryManagerCreate(&interface, memory_pool, kMemoryPoolSizeBytes, kPageSizeBytesLog); mgr = (MemoryManager*)interface; ASSERT_EQ(kNumUsablePages, mgr->ptable.max_pages); dev_ = {kDLCPU, 0}; diff --git a/tests/crt/session_test.cc b/tests/crt/session_test.cc index 60686be25060..9840f55dc685 100644 --- a/tests/crt/session_test.cc +++ b/tests/crt/session_test.cc @@ -18,7 +18,7 @@ */ #include -#include +#include #include #include diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index d75e1b607b8d..4da1f12b273a 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -235,7 +235,7 @@ def test_onnx(platform, west_cmd): target = tvm.target.target.micro(model, options=["-link-params=1"]) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): lowered = relay.build(relay_mod, target, params=params) - graph = lowered.get_json() + graph = lowered.get_graph_json() with _make_session(model, target, zephyr_board, west_cmd, lowered.lib) as session: graph_mod = tvm.micro.create_local_graph_executor( diff --git a/tests/python/relay/aot/aot_test.mk b/tests/python/relay/aot/aot_test.mk new file mode 100644 index 000000000000..ae8389561459 --- /dev/null +++ b/tests/python/relay/aot/aot_test.mk @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
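The offsets asserted in aot_memory_test.cc above follow directly from the stack allocator's arithmetic in stack_allocator.c: each request is padded so that next_alloc stays aligned to TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES (16), and with TVM_CRT_STACK_ALLOCATOR_ENABLE_FIFO_CHECK defined (the host crt_config.h above enables it) a 4-byte tag follows each block, which is why a 1-byte allocation advances the cursor by 16 + STACK_ALLOCATOR_TAG_SIZE_BYTES. A minimal standalone sketch of that arithmetic:

/* Standalone sketch of StackMemoryManager_Allocate's cursor arithmetic;
 * assumes the FIFO-check tag is enabled, mirroring the tests above. */
#include <stdint.h>
#include <stdio.h>

#define ALIGNMENT 16 /* TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES */
#define TAG_BYTES 4  /* STACK_ALLOCATOR_TAG_SIZE_BYTES */

int main(void) {
  uint32_t next_alloc = 0;            /* byte offset into model_memory */
  int32_t requests[] = {1, 2, 24, 1}; /* the sizes used by AOTMemory.Allocate */
  for (int i = 0; i < 4; i++) {
    int32_t nbytes = requests[i];
    /* pad the block so the next cursor stays 16-byte aligned */
    uint32_t offset = (ALIGNMENT - nbytes) & (ALIGNMENT - 1);
    printf("%2d-byte block placed at offset %u\n", nbytes, next_alloc);
    next_alloc += nbytes + offset + TAG_BYTES;
  }
  /* prints offsets 0, 20, 40, 76: that is 0, 16+4, 32+8 and 64+12,
   * matching the &model_memory[...] expectations in the tests. */
  return 0;
}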
+# Makefile to build aot_test_runner
+# Setup build environment
+#
+AOT_ROOT ?= $(TVM_ROOT)/src/runtime/crt/aot
+
+ENABLE_TVM_PLATFORM_ABORT_BACKTRACE = 0
+DMLC_CORE=$(TVM_ROOT)/3rdparty/dmlc-core
+PKG_COMPILE_OPTS = -g
+CC = gcc
+AR = ar
+RANLIB = ranlib
+CC_OPTS = CC=$(CC) AR=$(AR) RANLIB=$(RANLIB)
+
+
+PKG_CFLAGS = ${PKG_COMPILE_OPTS} \
+	-I$(TVM_ROOT)/src/runtime/crt/include \
+	-I$(TVM_ROOT)/src/runtime/crt/host \
+	-I$(TVM_ROOT)/include \
+	-I$(DMLC_CORE)/include \
+	-I$(TVM_ROOT)/3rdparty/dlpack/include \
+	-I$(AOT_ROOT)\
+	-I$(build_dir)
+
+ifeq ($(VERBOSE),1)
+QUIET ?=
+else
+QUIET ?= @
+endif
+
+CRT_SRCS = $(shell find $(CRT_ROOT))
+
+aot_test_runner: $(build_dir)/aot_test_runner
+
+source_libs= $(wildcard $(build_dir)/../codegen/host/src/lib*.c)
+lib_objs =$(source_libs:.c=.o)
+
+$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/aot_executor.o $(source_libs) $(build_dir)/stack_allocator.o $(build_dir)/crt_backend_api.o
+	$(QUIET)mkdir -p $(@D)
+	$(QUIET)$(CC) $(PKG_CFLAGS) -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS) -lm
+
+$(build_dir)/%.o: $(build_dir)/../codegen/host/src/%.c
+	$(QUIET)mkdir -p $(@D)
+	$(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)
+
+$(build_dir)/aot_executor.o: $(TVM_ROOT)/src/runtime/crt/aot_executor/aot_executor.c
+	$(QUIET)mkdir -p $(@D)
+	$(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)
+
+$(build_dir)/stack_allocator.o: $(TVM_ROOT)/src/runtime/crt/memory/stack_allocator.c
+	$(QUIET)mkdir -p $(@D)
+	$(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)
+
+$(build_dir)/crt_backend_api.o: $(TVM_ROOT)/src/runtime/crt/common/crt_backend_api.c
+	$(QUIET)mkdir -p $(@D)
+	$(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)
+
+clean:
+	$(QUIET)rm -rf $(build_dir)/crt
+cleanall:
+	$(QUIET)rm -rf $(build_dir)
+# Don't define implicit rules; they tend to match on logical target names that aren't targets (e.g. bundle_static)
+.SUFFIXES:
diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py
new file mode 100644
index 000000000000..8273d3954d3b
--- /dev/null
+++ b/tests/python/relay/aot/aot_test_utils.py
@@ -0,0 +1,225 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
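Note why crt_backend_api.o sits on the link line of the Makefile above: the operator code in the generated lib*.c sources does not call the allocator directly, it goes through the CRT backend API, which in turn calls the TVMPlatformMemoryAllocate/TVMPlatformMemoryFree hooks that the generated test main wires to the stack allocator (see create_main below). A sketch of that path, with simple_op standing in for a generated operator:

/* Sketch of the workspace-allocation path; simple_op is a hypothetical
 * stand-in for an operator emitted into the generated lib*.c sources. */
#include <stddef.h>

#include <tvm/runtime/c_backend_api.h>

int32_t simple_op(void) {
  /* device_type kDLCPU (1), device_id 0; the dtype hints are advisory */
  void* scratch = TVMBackendAllocWorkspace(1, 0, 256, 2, 32);
  if (scratch == NULL) {
    return -1;
  }
  /* ... operator body would write its temporaries into scratch ... */
  /* In this harness the call bottoms out in StackMemoryManager_Allocate,
   * so workspaces must be released in reverse order of allocation. */
  return TVMBackendFreeWorkspace(1, 0, scratch);
}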
+ +import os +import io +import struct +import numpy as np +import pathlib +import shutil +import subprocess +import tempfile +import tarfile + + +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.contrib import utils, graph_executor +from tvm.relay.backend import compile_engine +from tvm.contrib import utils +from tvm.micro import export_model_library_format + + +def subprocess_with_stdout_and_log(cmd, cwd, logfile, stdout): + """ + This method runs a process and logs the output to both a log file and stdout + """ + with subprocess.Popen( + cmd, cwd=cwd, shell=True, bufsize=0, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) as proc, open(logfile, "a") as f: + while True: + data = proc.stdout.readline() + result = proc.poll() + # process is done if there is no data and the result is valid + if data == b"" and result is not None: + return int(result) + if data: + text = data.decode("ascii", errors="backslashreplace") + f.write(text) + if stdout: + print(text, end="") + + +def create_main(test_name, input_list, output_list, output_path): + file_path = pathlib.Path(f"{output_path}/" + test_name).resolve() + # create header file + raw_path = file_path.with_suffix(".c").resolve() + with open(raw_path, "w") as main_file: + main_file.write("#include \n") + main_file.write("#include \n") + main_file.write('#include "tvm/runtime/crt/internal/aot_executor/aot_executor.h"\n') + main_file.write('#include "tvm/runtime/crt/stack_allocator.h"\n') + main_file.write("#define WORKSPACE_SIZE (16384*1024)\n") + main_file.write("static uint8_t g_aot_memory[WORKSPACE_SIZE];\n") + + for i in range(0, len(input_list)): + main_file.write('#include "input_data%i.h"\n' % i) + for i in range(0, len(output_list)): + main_file.write('#include "expected_output_data%i.h"\n' % i) + main_file.write('#include "output_data%i.h"\n' % i) + + main_file.write("extern tvm_model_t network;\n") + main_file.write("tvm_workspace_t app_workspace;\n") + main_file.write( + """ +tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { + return StackMemoryManager_Allocate(&app_workspace, num_bytes, out_ptr); +} + +tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { + return StackMemoryManager_Free(&app_workspace,ptr); +} + +void TVMPlatformAbort(tvm_crt_error_t code) { } + +void TVMLogf(const char* msg, ...) 
{ }
+
+TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) { return 0; }
+
+    """
+    )
+    main_file.write("int main(){\n")
+    main_file.write("void* inputs[%i] = { " % (len(input_list)))
+
+    for i in range(0, len(input_list)):
+        main_file.write("input_data%i, " % i)
+    main_file.write("};\n")
+
+    main_file.write("void* outputs[%i] = { " % (len(output_list)))
+    for i in range(0, len(output_list)):
+        main_file.write("output_data%i, " % i)
+    main_file.write("};\n")
+
+    main_file.write("StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE);")
+    main_file.write("tvm_runtime_run(&network, inputs, outputs);")
+
+    for i in range(0, len(output_list)):
+        is_float_dtype = output_list[i].dtype == "float32"
+        main_file.write("for (int i = 0; i<output_data%i_len; i++){\n" % i)
+        if is_float_dtype:
+            main_file.write(
+                'if (fabs(output_data%s[i]-expected_output_data%s[i]) > 0.001f){printf("ko\\n");return -1;}\n'
+                % (i, i)
+            )
+        else:
+            main_file.write(
+                'if (output_data%s[i]!=expected_output_data%s[i]){printf("ko\\n");return -1;}\n'
+                % (i, i)
+            )
+        main_file.write("}\n")
+
+    main_file.write('printf("ok\\n");')
+    main_file.write("return 0;")
+    main_file.write("}\n")
+
+
+def create_header_file(tensor_name, npy_data, output_path):
+    """
+    This method generates a header file containing the data contained in the numpy array provided.
+    It is used to capture the tensor data (for both inputs and expected outputs) to be bundled into
+    the standalone aot_test_runner.
+    """
+    file_path = pathlib.Path(f"{output_path}/" + tensor_name).resolve()
+    # create header file
+    raw_path = file_path.with_suffix(".h").resolve()
+    with open(raw_path, "w") as header_file:
+        header_file.write("#include <stddef.h>\n")
+        header_file.write("#include <stdint.h>\n")
+        header_file.write("#include <dlpack/dlpack.h>\n")
+        header_file.write(f"const size_t {tensor_name}_len = {npy_data.size};\n")
+
+        if npy_data.dtype == "int8":
+            header_file.write(f"int8_t {tensor_name}[] =")
+        elif npy_data.dtype == "int32":
+            header_file.write(f"int32_t {tensor_name}[] = ")
+        elif npy_data.dtype == "uint8":
+            header_file.write(f"uint8_t {tensor_name}[] = ")
+        elif npy_data.dtype == "float32":
+            header_file.write(f"float {tensor_name}[] = ")
+
+        header_file.write("{")
+        for i in np.ndindex(npy_data.shape):
+            header_file.write(f"{npy_data[i]}, ")
+        header_file.write("};\n\n")
+
+
+def compile_and_run(mod, input_list, output_list, params=None):
+    """
+    This method compiles and runs the generated C source against the AOT test harness.
+    """
+    target = "c -runtime=c --link-params --executor=aot"
+
+    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+        lib = tvm.relay.build(mod, target, target_host=target, params=params)
+
+    tmp_path = utils.tempdir()
+    tmp_dir = tmp_path.temp_dir
+
+    base_path = os.path.join(tmp_dir, "test")
+    build_path = os.path.join(base_path, "build")
+    os.makedirs(build_path, exist_ok=True)
+
+    tar_file = os.path.join(base_path, "test.tar")
+    export_model_library_format(lib, tar_file)
+    t = tarfile.open(tar_file)
+    t.extractall(base_path)
+
+    for i in range(len(input_list)):
+        create_header_file((f"input_data{i}"), input_list[i], build_path)
+
+    for i in range(len(output_list)):
+        create_header_file(
+            (f"output_data{i}"),
+            np.zeros(output_list[i].shape, output_list[i].dtype),
+            build_path,
+        )
+        create_header_file((f"expected_output_data{i}"), output_list[i], build_path)
+
+    create_main("test.c", input_list, output_list, build_path)
+
+    # Verify that it compiles fine
+    file_dir = os.path.dirname(os.path.abspath(__file__))
+    makefile = os.path.join(file_dir, "aot_test.mk")
+    make_cmd = f"make -f {makefile} build_dir=" + build_path + f" TVM_ROOT={file_dir}/../../../.."
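+    # aot_test.mk (above) compiles test.c together with the lib*.c operator
+    # sources unpacked from the Model Library Format tar, plus aot_executor.o,
+    # stack_allocator.o and crt_backend_api.o, and links the aot_test_runner.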
+
+    compile_log_path = os.path.join(build_path, "test_compile.log")
+    ret = subprocess_with_stdout_and_log(make_cmd, ".", compile_log_path, False)
+    assert ret == 0
+
+    # Verify that the binary runs fine
+    run_log_path = os.path.join(build_path, "test_run.log")
+    ret = subprocess_with_stdout_and_log("./aot_test_runner", build_path, run_log_path, False)
+    assert ret == 0
+
+
+def generate_ref_data(mod, input_data, params=None, target="llvm"):
+    """Generate reference data by executing the Relay module with the graph executor."""
+    compile_engine.get().clear()
+    with tvm.transform.PassContext(opt_level=3):
+        lib = relay.build(mod, target=target, params=params)
+
+    lib_name = "mod.so"
+    temp = utils.tempdir()
+    lib_path = temp.relpath(lib_name)
+    lib.export_library(lib_path)
+    lib = tvm.runtime.load_module(lib_path)
+    grt_mod = graph_executor.GraphModule(lib["default"](tvm.cpu()))
+    grt_mod.set_input(**input_data)
+    grt_mod.run()
+    output_count = grt_mod.get_num_outputs()
+    out = [grt_mod.get_output(i).asnumpy() for i in range(output_count)]
+    return out
diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py
new file mode 100644
index 000000000000..0f1f2ad369e7
--- /dev/null
+++ b/tests/python/relay/aot/test_crt_aot.py
@@ -0,0 +1,349 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
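+
+# Every test below follows the same pattern: build a Relay module, produce
+# reference outputs with the graph executor (generate_ref_data), then compile
+# the same module with the AOT executor and run the generated C code on the
+# host (compile_and_run). A minimal sketch of the pattern, with illustrative
+# values taken from test_add_with_params:
+#
+#     x_in = np.ones((1, 10)).astype("float32")
+#     y_in = np.random.uniform(size=(1, 10)).astype("float32")
+#     func = relay.Function([x, y], relay.add(x, y))
+#     output_list = generate_ref_data(func, {"y": y_in}, params={"x": x_in})
+#     compile_and_run(func, [y_in], output_list, params={"x": x_in})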
+
+import os
+import io
+import struct
+import numpy as np
+import pathlib
+import shutil
+import subprocess
+import tempfile
+import tarfile
+import pytest
+
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.op.contrib import get_pattern_table
+from tvm.contrib import utils
+from tvm.contrib import graph_executor
+from tvm.relay.backend import compile_engine
+from tvm.micro import export_model_library_format
+from tvm.relay import testing
+from tvm.relay.op.annotation import compiler_begin, compiler_end
+from tvm.relay.expr_functor import ExprMutator
+
+from aot_test_utils import *
+
+
+def test_conv_with_params():
+    RELAY_MODEL = """
+#[version = "0.0.5"]
+def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), int8]) {
+    %1 = nn.conv2d(
+         %data,
+         %weight,
+         padding=[2, 2],
+         channels=8,
+         kernel_size=[5, 5],
+         data_layout="NCHW",
+         kernel_layout="OIHW",
+         out_dtype="int32");
+    %1
+}
+"""
+    mod = tvm.parser.fromtext(RELAY_MODEL)
+    main_func = mod["main"]
+    shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params}
+    type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params}
+
+    weight_data = np.ones(shape_dict["weight"]).astype(type_dict["weight"])
+    input_data = np.ones(shape_dict["data"]).astype(type_dict["data"])
+
+    params = {"weight": weight_data}
+    inputs = {"data": input_data}
+    output_list = generate_ref_data(mod, inputs, params)
+
+    input_list = [input_data]
+    compile_and_run(mod, input_list, output_list, params)
+
+
+def test_add_with_params():
+    x = relay.var("x", shape=(1, 10))
+    y = relay.var("y", shape=(1, 10))
+    z = relay.add(x, y)
+    func = relay.Function([x, y], z)
+
+    x_in = np.ones((1, 10)).astype("float32")
+    y_in = np.random.uniform(size=(1, 10)).astype("float32")
+
+    params = {"x": x_in}
+    inputs = {"y": y_in}
+    output_list = generate_ref_data(func, inputs, params)
+
+    input_list = [y_in]
+    compile_and_run(func, input_list, output_list, params)
+
+
+def test_conv2d():
+    """Test a subgraph with a single conv2d operator."""
+
+    def conv2d_direct():
+        dtype = "float32"
+        ishape = (1, 32, 14, 14)
+        w1shape = (32, 32, 3, 3)
+
+        data0 = relay.var("data", shape=ishape, dtype=dtype)
+        weight0 = relay.var("weight", shape=w1shape, dtype=dtype)
+        out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1))
+        main_f = relay.Function([data0, weight0], out)
+        mod = tvm.IRModule()
+        mod["main"] = main_f
+        mod = transform.InferType()(mod)
+
+        i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+        w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
+
+        return mod, {"data": i_data, "weight": w1_data}, (1, 32, 14, 14)
+
+    def group_conv2d():
+        dtype = "float32"
+        ishape = (1, 32, 14, 14)
+        w2shape = (32, 1, 3, 3)
+
+        data0 = relay.var("data", shape=(ishape), dtype=dtype)
+        weight0 = relay.var("weight", shape=(w2shape), dtype=dtype)
+        out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1), groups=32)
+        main_f = relay.Function([data0, weight0], out)
+        mod = tvm.IRModule()
+        mod["main"] = main_f
+        mod = transform.InferType()(mod)
+
+        i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+        w_data = np.random.uniform(0, 1, w2shape).astype(dtype)
+
+        return mod, {"data": i_data, "weight": w_data}, (1, 32, 14, 14)
+
+    for mod, inputs, out_shape in [conv2d_direct(), group_conv2d()]:
+        output_list = generate_ref_data(mod, inputs)
+        input_list = [inputs["data"], inputs["weight"]]
+        compile_and_run(mod, input_list, output_list)
+
+
+def test_concatenate():
+    dtype = "float32"
+    x = relay.var("x", shape=(10, 5), dtype=dtype)
+    y = relay.var("y", shape=(10, 5), dtype=dtype)
+    t = relay.var("z", shape=(), dtype=dtype)
+    z = relay.concatenate((x, y), axis=1)
+    z = relay.add(z, t)
+    # Check result.
+    func = relay.Function([x, y, t], z)
+    x_data = np.random.rand(10, 5).astype(dtype)
+    y_data = np.random.rand(10, 5).astype(dtype)
+    t_data = np.random.uniform(size=()).astype(dtype)
+    inputs = {"x": x_data, "y": y_data, "z": t_data}
+
+    output_list = generate_ref_data(func, inputs)
+    input_list = [inputs["x"], inputs["y"], inputs["z"]]
+    compile_and_run(func, input_list, output_list)
+
+
+def test_nested_tuples():
+    x = relay.var("x", shape=(10,))
+    x1 = x + relay.const(1.0)
+    x2 = x1 + relay.const(1.0)
+    x3 = x2 + relay.const(1.0)
+    x4 = x3 + relay.const(1.0)
+    out = relay.Tuple([x1, relay.Tuple([relay.Tuple([x2, x3]), x4])])
+    func = relay.Function([x], out)
+
+    x_data = np.random.uniform(size=(10,)).astype(np.float32)
+    inputs = {"x": x_data}
+    output_list = generate_ref_data(func, inputs)
+    input_list = [x_data]
+    compile_and_run(func, input_list, output_list)
+
+
+def test_tuple_getitem():
+    func = relay.Function([], relay.TupleGetItem(relay.Tuple([relay.const(1), relay.const(2)]), 0))
+    output_list = generate_ref_data(func, {})
+    input_list = []
+    compile_and_run(func, input_list, output_list)
+
+
+def test_id():
+    x = relay.var("x", "float32")
+    ident = relay.Function([x], x)
+    one = np.array(1.0, "float32")
+    inputs = {"x": one}
+    output_list = generate_ref_data(ident, inputs)
+    input_list = [one]
+    compile_and_run(ident, input_list, output_list)
+
+
+def test_add_const():
+    two = relay.add(relay.const(1), relay.const(1))
+    func = relay.Function([], two)
+    output_list = generate_ref_data(func, {})
+    input_list = []
+    compile_and_run(func, input_list, output_list)
+
+
+def test_mul_param():
+    x = relay.var("x", shape=(10, 10))
+    y = relay.var("y", shape=(1, 10))
+    func = relay.Function([x, y], relay.multiply(x, y))
+    x_data = np.random.rand(10, 10).astype("float32")
+    y_data = np.random.rand(1, 10).astype("float32")
+    inputs = {"x": x_data, "y": y_data}
+    output_list = generate_ref_data(func, inputs)
+    input_list = [inputs["x"], inputs["y"]]
+    compile_and_run(func, input_list, output_list)
+
+
+def test_subtract():
+    i = relay.var("i", shape=[], dtype="int32")
+    sub = relay.subtract(i, relay.const(1, dtype="int32"))
+    func = relay.Function([i], sub, ret_type=relay.TensorType([], "int32"))
+    i_data = np.array(1, dtype="int32")
+    inputs = {"i": i_data}
+    output_list = generate_ref_data(func, inputs)
+    input_list = [inputs["i"]]
+    compile_and_run(func, input_list, output_list)
+
+
+def test_tuple_output():
+    x = relay.var("x", shape=(6, 9))
+    y = relay.split(x, 3).astuple()
+    a = relay.TupleGetItem(y, 0)
+    b = relay.TupleGetItem(y, 1)
+    c = relay.TupleGetItem(y, 2)
+    out = relay.Tuple([a, b])
+    func = relay.Function([x], out)
+    x_data = np.random.rand(6, 9).astype("float32")
+    inputs = {"x": x_data}
+    output_list = generate_ref_data(func, inputs)
+    input_list = [inputs["x"]]
+    compile_and_run(func, input_list, output_list)
+
+
+def test_mobilenet():
+    mod, params = testing.mobilenet.get_workload(batch_size=1)
+    data_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape]
+    data = np.random.uniform(size=data_shape).astype("float32")
+    inputs = {"data": data}
+    output_list = generate_ref_data(mod, inputs, params)
+    input_list = [inputs["data"]]
+    compile_and_run(mod, input_list, output_list, params)
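+
+
+# The annotator below wraps an add -> subtract -> multiply chain in
+# compiler_begin/compiler_end markers so that PartitionGraph() can lift it
+# into an external "ccompiler" function, leaving the rest of the graph to TVM;
+# see test_byoc_utvm further down for how it is used.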
+class CcompilerAnnotator(ExprMutator):
+    """
+    This is used to create external functions for ccompiler.
+    A simple annotator that creates the following program:
+           |
+      -- begin --
+           |
+          add
+           |
+        subtract
+           |
+        multiply
+           |
+       -- end --
+           |
+    """
+
+    def __init__(self):
+        super(CcompilerAnnotator, self).__init__()
+        self.in_compiler = 0
+
+    def visit_call(self, call):
+        if call.op.name == "add":  # Annotate begin at args
+            if self.in_compiler == 1:
+                lhs = compiler_begin(super().visit(call.args[0]), "ccompiler")
+                rhs = compiler_begin(super().visit(call.args[1]), "ccompiler")
+                op = relay.add(lhs, rhs)
+                self.in_compiler = 2
+                return op
+        elif call.op.name == "subtract":
+            if self.in_compiler == 1:
+                lhs = super().visit(call.args[0])
+                rhs = super().visit(call.args[1])
+                if isinstance(lhs, relay.expr.Var):
+                    lhs = compiler_begin(lhs, "ccompiler")
+                if isinstance(rhs, relay.expr.Var):
+                    rhs = compiler_begin(rhs, "ccompiler")
+                return relay.subtract(lhs, rhs)
+        elif call.op.name == "multiply":  # Annotate end at output
+            self.in_compiler = 1
+            lhs = super().visit(call.args[0])
+            rhs = super().visit(call.args[1])
+            if isinstance(lhs, relay.expr.Var):
+                lhs = compiler_begin(lhs, "ccompiler")
+            if isinstance(rhs, relay.expr.Var):
+                rhs = compiler_begin(rhs, "ccompiler")
+            op = relay.multiply(lhs, rhs)
+            if self.in_compiler == 2:
+                op = compiler_end(op, "ccompiler")
+            self.in_compiler = 0
+            return op
+        return super().visit_call(call)
+
+
+def test_byoc_utvm():
+    """This is a simple test case to check BYOC capabilities of AOT"""
+    x = relay.var("x", shape=(10, 10))
+    w0 = relay.var("w0", shape=(10, 10))
+    w1 = relay.var("w1", shape=(10, 10))
+    w2 = relay.var("w2", shape=(10, 10))
+    w3 = relay.var("w3", shape=(10, 10))
+    w4 = relay.var("w4", shape=(10, 10))
+    w5 = relay.var("w5", shape=(10, 10))
+    w6 = relay.var("w6", shape=(10, 10))
+    w7 = relay.var("w7", shape=(10, 10))
+
+    # C compiler
+    z0 = relay.add(x, w0)
+    p0 = relay.subtract(z0, w1)
+    q0 = relay.multiply(p0, w2)
+
+    z1 = relay.add(x, w3)
+    p1 = relay.subtract(z1, w4)
+    q1 = relay.multiply(p1, w5)
+
+    # Other parts on TVM
+    z2 = relay.add(x, w6)
+    q2 = relay.subtract(z2, w7)
+
+    r = relay.concatenate((q0, q1, q2), axis=0)
+    f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r)
+    mod = tvm.IRModule()
+    ann = CcompilerAnnotator()
+    mod["main"] = ann.visit(f)
+    mod = tvm.relay.transform.PartitionGraph()(mod)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    x_data = np.random.rand(10, 10).astype("float32")
+    w_data = []
+    for _ in range(8):
+        w_data.append(np.random.rand(10, 10).astype("float32"))
+
+    map_inputs = {"w{}".format(i): w_data[i] for i in range(8)}
+    map_inputs["x"] = x_data
+    output_list = generate_ref_data(mod, map_inputs)
+    input_list = [map_inputs["x"]]
+    input_list.extend([map_inputs["w{}".format(i)] for i in range(8)])
+    compile_and_run(mod, input_list, output_list)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py
index 8e6fe298351e..06623e0baa24 100644
--- a/tests/python/relay/test_backend_graph_executor.py
+++ b/tests/python/relay/test_backend_graph_executor.py
@@ -133,10 +133,12 @@ def test_plan_memory():
     smap = relay.backend._backend.GraphPlanMemory(func)
     storage_ids = set()
     device_types = set()
+    storage_sizes = {}
     for k, v in smap.items():
-        assert len(v) == 2
+        assert len(v) == 3
         for x in v[0]:
             storage_ids.add(x.value)
+            storage_sizes[x.value] = v[2]
         for x in v[1]:
             device_types.add(x.value)
@@ -145,6 +147,15 @@
     # two alternating temporary space.
     assert len(storage_ids) == 4
     assert len(device_types) == 1
+    assert len(storage_sizes) == 4
+
+    # Check the specific size of each sid
+    assert (
+        storage_sizes[0][0] == 40
+        and storage_sizes[1][0] == 4
+        and storage_sizes[2][0] == 4
+        and storage_sizes[3][0] == 40
+    )
 
 
 def test_reshape_nop():
@@ -162,7 +173,7 @@
     func = relay.Function([x], relay.Tuple([z0, z1, z2]))
     x_data = np.random.rand(10, 4).astype("float32")
     graph = relay.build(tvm.IRModule.from_expr(func), "llvm")
-    graph_json_str = graph.get_json()
+    graph_json_str = graph.get_graph_json()
 
     graph_json = json.loads(graph_json_str)
diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py
index 9f6d88e47f0b..be92ef200c31 100644
--- a/tests/python/relay/test_external_codegen.py
+++ b/tests/python/relay/test_external_codegen.py
@@ -353,7 +353,7 @@ def test_load_params_with_constants_in_ext_codegen():
     graph_module = relay.build(mod, target="llvm", params=params)
     lib = update_lib(graph_module.get_lib())
-    rt_mod = tvm.contrib.graph_executor.create(graph_module.get_json(), lib, tvm.cpu(0))
+    rt_mod = tvm.contrib.graph_executor.create(graph_module.get_graph_json(), lib, tvm.cpu(0))
     rt_mod.load_params(runtime.save_param_dict(graph_module.get_params()))
diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py
index a9c31f5ccedd..abf795cd46cc 100644
--- a/tests/python/relay/test_pass_annotation.py
+++ b/tests/python/relay/test_pass_annotation.py
@@ -266,7 +266,7 @@ def check_storage_and_device_types():
     storage_ids = []
     device_types = []
     for _, storage_dev_type in smap.items():
-        assert len(storage_dev_type) == 2
+        assert len(storage_dev_type) == 3
         for sid in storage_dev_type[0]:
             storage_ids.append(sid.value)
         for did in storage_dev_type[1]:
diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py
index 6d678b8a3753..c6902429c0cd 100644
--- a/tests/python/unittest/test_crt.py
+++ b/tests/python/unittest/test_crt.py
@@ -158,7 +158,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) {
 
     with _make_session(workspace, factory.get_lib()) as sess:
         graph_mod = tvm.micro.create_local_graph_executor(
-            factory.get_json(), sess.get_system_lib(), sess.device
+            factory.get_graph_json(), sess.get_system_lib(), sess.device
         )
         A_data = tvm.nd.array(np.array([2, 3], dtype="uint8"), device=sess.device)
         assert (A_data.asnumpy() == np.array([2, 3])).all()
diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py
index db6c55bca12a..712bd8d348a2 100644
--- a/tests/python/unittest/test_micro_model_library_format.py
+++ b/tests/python/unittest/test_micro_model_library_format.py
@@ -26,7 +26,7 @@
 
 import tvm
 import tvm.relay
-from tvm.relay.backend import graph_executor_factory
+from tvm.relay.backend import executor_factory
 import tvm.runtime.module
 import tvm.testing
 from tvm.contrib import utils
@@ -170,7 +170,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[
 @tvm.testing.requires_micro
 def test_export_model():
     module = tvm.support.FrontendTestModule()
-    factory = graph_executor_factory.GraphExecutorFactoryModule(
+    factory = executor_factory.GraphExecutorFactoryModule(
         None, tvm.target.target.micro("host"), '"graph_json"', module, "test_module", {}
     )
diff --git a/tests/python/unittest/test_runtime_graph.py b/tests/python/unittest/test_runtime_graph.py
index aac7e497f38f..44f20878b800 100644
--- a/tests/python/unittest/test_runtime_graph.py
+++ b/tests/python/unittest/test_runtime_graph.py
@@ -131,7 +131,7 @@ def test_load_unexpected_params():
     graph_module = relay.build(mod, target="llvm", params=params)
 
     rt_mod = tvm.contrib.graph_executor.create(
-        graph_module.get_json(), graph_module.get_lib(), tvm.cpu(0)
+        graph_module.get_graph_json(), graph_module.get_lib(), tvm.cpu(0)
     )
 
     new_params = graph_module.get_params()
diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index 1d80c60de790..f85edfc8d033 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -543,7 +543,7 @@ def test_debug_graph_executor():
     debug_g_mod = debug_executor.GraphModuleDebug(
         complied_graph_lib["debug_create"]("default", dev),
         [dev],
-        complied_graph_lib.get_json(),
+        complied_graph_lib.get_graph_json(),
         None,
     )
     debug_g_mod.set_input("data", data)
diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py
index 1bb3c364df17..ee8032550b39 100644
--- a/tests/python/unittest/test_runtime_profiling.py
+++ b/tests/python/unittest/test_runtime_profiling.py
@@ -54,7 +54,7 @@ def test_graph_executor(target, dev):
     mod, params = mlp.get_workload(1)
     exe = relay.build(mod, target, params=params)
-    gr = debug_executor.create(exe.get_json(), exe.lib, dev)
+    gr = debug_executor.create(exe.get_graph_json(), exe.lib, dev)
 
     data = np.random.rand(1, 1, 28, 28).astype("float32")
     report = gr.profile(data=data)