Skip to content

Commit

Permalink
[refactor] Unify implementation of ProgramImpl::compile()
Browse files Browse the repository at this point in the history
ghstack-source-id: 0434f8a9ade8e36e8554633f442cae1faf50f8c0
Pull Request resolved: #7698
  • Loading branch information
PGZXB authored and Taichi Gardener committed Apr 3, 2023
1 parent 063d03c commit 9857410
Show file tree
Hide file tree
Showing 12 changed files with 42 additions and 78 deletions.
13 changes: 13 additions & 0 deletions taichi/program/program_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ void ProgramImpl::compile_snode_tree_types(SNodeTree *tree) {
TI_NOT_IMPLEMENTED;
}

FunctionType ProgramImpl::compile(const CompileConfig &compile_config,
Kernel *kernel) {
// NOTE: Temporary implementation (blocked by cc backend)
// TODO(PGZXB): Final solution: compile -> load_or_compile + launch_kernel
auto &mgr = get_kernel_compilation_manager();
const auto &compiled =
mgr.load_or_compile(compile_config, get_device_caps(), *kernel);
auto &launcher = get_kernel_launcher();
return [&launcher, &compiled](LaunchContextBuilder &ctx_builder) {
launcher.launch_kernel(compiled, ctx_builder);
};
}

void ProgramImpl::dump_cache_data_to_disk() {
auto &mgr = get_kernel_compilation_manager();
mgr.clean_offline_cache(offline_cache::string_to_clean_cache_policy(
Expand Down
6 changes: 5 additions & 1 deletion taichi/program/program_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ProgramImpl {
* Codegen to specific backend
*/
virtual FunctionType compile(const CompileConfig &compile_config,
Kernel *kernel) = 0;
Kernel *kernel);

/**
* Allocate runtime buffer, e.g result_buffer or backend specific runtime
Expand Down Expand Up @@ -181,6 +181,10 @@ class ProgramImpl {
TI_NOT_IMPLEMENTED;
}

virtual DeviceCapabilityConfig get_device_caps() {
return {};
}

private:
std::unique_ptr<KernelCompilationManager> kernel_com_mgr_;
std::unique_ptr<KernelLauncher> kernel_launcher_;
Expand Down
18 changes: 5 additions & 13 deletions taichi/runtime/program_impls/dx/dx_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,6 @@ namespace taichi::lang {
Dx11ProgramImpl::Dx11ProgramImpl(CompileConfig &config) : ProgramImpl(config) {
}

FunctionType Dx11ProgramImpl::compile(const CompileConfig &compile_config,
Kernel *kernel) {
// NOTE: Temporary implementation
// TODO(PGZXB): Final solution: compile -> load_or_compile + launch_kernel
auto &mgr = get_kernel_compilation_manager();
const auto &compiled = mgr.load_or_compile(
compile_config, runtime_->get_ti_device()->get_caps(), *kernel);
auto &launcher = get_kernel_launcher();
return [&launcher, &compiled](LaunchContextBuilder &ctx_builder) {
launcher.launch_kernel(compiled, ctx_builder);
};
}

void Dx11ProgramImpl::materialize_runtime(KernelProfilerBase *profiler,
uint64 **result_buffer_ptr) {
*result_buffer_ptr =
Expand Down Expand Up @@ -94,6 +81,11 @@ std::unique_ptr<KernelLauncher> Dx11ProgramImpl::make_kernel_launcher() {
return std::make_unique<gfx::KernelLauncher>(std::move(cfg));
}

DeviceCapabilityConfig Dx11ProgramImpl::get_device_caps() {
TI_ASSERT(runtime_);
return runtime_->get_ti_device()->get_caps();
}

} // namespace taichi::lang

#endif
3 changes: 1 addition & 2 deletions taichi/runtime/program_impls/dx/dx_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ namespace taichi::lang {
class Dx11ProgramImpl : public ProgramImpl {
public:
Dx11ProgramImpl(CompileConfig &config);
FunctionType compile(const CompileConfig &compile_config,
Kernel *kernel) override;

std::size_t get_snode_num_dynamically_allocated(
SNode *snode,
Expand Down Expand Up @@ -65,6 +63,7 @@ class Dx11ProgramImpl : public ProgramImpl {
protected:
std::unique_ptr<KernelCompiler> make_kernel_compiler() override;
std::unique_ptr<KernelLauncher> make_kernel_launcher() override;
DeviceCapabilityConfig get_device_caps() override;

private:
std::shared_ptr<Device> device_{nullptr};
Expand Down
12 changes: 0 additions & 12 deletions taichi/runtime/program_impls/llvm/llvm_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,6 @@ LlvmProgramImpl::LlvmProgramImpl(CompileConfig &config_,
cache_data_ = std::make_unique<LlvmOfflineCache>();
}

FunctionType LlvmProgramImpl::compile(const CompileConfig &compile_config,
Kernel *kernel) {
// NOTE: Temporary implementation
// TODO(PGZXB): Final solution: compile -> load_or_compile + launch_kernel
auto &mgr = get_kernel_compilation_manager();
const auto &compiled = mgr.load_or_compile(compile_config, {}, *kernel);
auto &launcher = get_kernel_launcher();
return [&launcher, &compiled](LaunchContextBuilder &ctx_builder) {
launcher.launch_kernel(compiled, ctx_builder);
};
}

std::unique_ptr<StructCompiler> LlvmProgramImpl::compile_snode_tree_types_impl(
SNodeTree *tree) {
auto *const root = tree->root();
Expand Down
5 changes: 0 additions & 5 deletions taichi/runtime/program_impls/llvm/llvm_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,6 @@ class LlvmProgramImpl : public ProgramImpl {
/* ---- JIT-Compilation Interfaces ---- */
/* ------------------------------------ */

// TODO(zhanlue): compile-time runtime split for LLVM::CodeGen
// For now, compile = codegen + convert
FunctionType compile(const CompileConfig &compile_config,
Kernel *kernel) override;

void compile_snode_tree_types(SNodeTree *tree) override;

// TODO(zhanlue): refactor materialize_snode_tree()
Expand Down
18 changes: 5 additions & 13 deletions taichi/runtime/program_impls/metal/metal_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,6 @@ MetalProgramImpl::MetalProgramImpl(CompileConfig &config)
: ProgramImpl(config) {
}

FunctionType MetalProgramImpl::compile(const CompileConfig &compile_config,
Kernel *kernel) {
// NOTE: Temporary implementation
// TODO(PGZXB): Final solution: compile -> load_or_compile + launch_kernel
auto &mgr = get_kernel_compilation_manager();
const auto &compiled = mgr.load_or_compile(
compile_config, gfx_runtime_->get_ti_device()->get_caps(), *kernel);
auto &launcher = get_kernel_launcher();
return [&launcher, &compiled](LaunchContextBuilder &ctx_builder) {
launcher.launch_kernel(compiled, ctx_builder);
};
}

void MetalProgramImpl::materialize_runtime(KernelProfilerBase *profiler,
uint64 **result_buffer_ptr) {
*result_buffer_ptr =
Expand Down Expand Up @@ -113,4 +100,9 @@ std::unique_ptr<KernelLauncher> MetalProgramImpl::make_kernel_launcher() {
return std::make_unique<gfx::KernelLauncher>(std::move(cfg));
}

DeviceCapabilityConfig MetalProgramImpl::get_device_caps() {
TI_ASSERT(gfx_runtime_);
return gfx_runtime_->get_ti_device()->get_caps();
}

} // namespace taichi::lang
3 changes: 1 addition & 2 deletions taichi/runtime/program_impls/metal/metal_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ namespace taichi::lang {
class MetalProgramImpl : public ProgramImpl {
public:
explicit MetalProgramImpl(CompileConfig &config);
FunctionType compile(const CompileConfig &compile_config,
Kernel *kernel) override;

std::size_t get_snode_num_dynamically_allocated(
SNode *snode,
Expand Down Expand Up @@ -93,6 +91,7 @@ class MetalProgramImpl : public ProgramImpl {
protected:
std::unique_ptr<KernelCompiler> make_kernel_compiler() override;
std::unique_ptr<KernelLauncher> make_kernel_launcher() override;
DeviceCapabilityConfig get_device_caps() override;

private:
std::unique_ptr<metal::MetalDevice> embedded_device_{nullptr};
Expand Down
18 changes: 5 additions & 13 deletions taichi/runtime/program_impls/opengl/opengl_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,6 @@ OpenglProgramImpl::OpenglProgramImpl(CompileConfig &config)
: ProgramImpl(config) {
}

FunctionType OpenglProgramImpl::compile(const CompileConfig &compile_config,
Kernel *kernel) {
// NOTE: Temporary implementation
// TODO(PGZXB): Final solution: compile -> load_or_compile + launch_kernel
auto &mgr = get_kernel_compilation_manager();
const auto &compiled = mgr.load_or_compile(
compile_config, runtime_->get_ti_device()->get_caps(), *kernel);
auto &launcher = get_kernel_launcher();
return [&launcher, &compiled](LaunchContextBuilder &ctx_builder) {
launcher.launch_kernel(compiled, ctx_builder);
};
}

void OpenglProgramImpl::materialize_runtime(KernelProfilerBase *profiler,
uint64 **result_buffer_ptr) {
*result_buffer_ptr =
Expand Down Expand Up @@ -107,4 +94,9 @@ std::unique_ptr<KernelLauncher> OpenglProgramImpl::make_kernel_launcher() {
return std::make_unique<gfx::KernelLauncher>(std::move(cfg));
}

DeviceCapabilityConfig OpenglProgramImpl::get_device_caps() {
TI_ASSERT(runtime_);
return runtime_->get_ti_device()->get_caps();
}

} // namespace taichi::lang
3 changes: 1 addition & 2 deletions taichi/runtime/program_impls/opengl/opengl_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ class OpenglProgramImpl : public ProgramImpl {
public:
explicit OpenglProgramImpl(CompileConfig &config);
~OpenglProgramImpl() override;
FunctionType compile(const CompileConfig &compile_config,
Kernel *kernel) override;

std::size_t get_snode_num_dynamically_allocated(
SNode *snode,
Expand Down Expand Up @@ -72,6 +70,7 @@ class OpenglProgramImpl : public ProgramImpl {
protected:
std::unique_ptr<KernelCompiler> make_kernel_compiler() override;
std::unique_ptr<KernelLauncher> make_kernel_launcher() override;
DeviceCapabilityConfig get_device_caps() override;

private:
std::shared_ptr<Device> device_{nullptr};
Expand Down
18 changes: 5 additions & 13 deletions taichi/runtime/program_impls/vulkan/vulkan_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,6 @@ VulkanProgramImpl::VulkanProgramImpl(CompileConfig &config)
: ProgramImpl(config) {
}

FunctionType VulkanProgramImpl::compile(const CompileConfig &compile_config,
Kernel *kernel) {
// NOTE: Temporary implementation
// TODO(PGZXB): Final solution: compile -> load_or_compile + launch_kernel
auto &mgr = get_kernel_compilation_manager();
const auto &compiled = mgr.load_or_compile(
compile_config, vulkan_runtime_->get_ti_device()->get_caps(), *kernel);
auto &launcher = get_kernel_launcher();
return [&launcher, &compiled](LaunchContextBuilder &ctx_builder) {
launcher.launch_kernel(compiled, ctx_builder);
};
}

void VulkanProgramImpl::materialize_runtime(KernelProfilerBase *profiler,
uint64 **result_buffer_ptr) {
*result_buffer_ptr =
Expand Down Expand Up @@ -222,6 +209,11 @@ std::unique_ptr<KernelLauncher> VulkanProgramImpl::make_kernel_launcher() {
return std::make_unique<gfx::KernelLauncher>(std::move(cfg));
}

DeviceCapabilityConfig VulkanProgramImpl::get_device_caps() {
TI_ASSERT(vulkan_runtime_);
return vulkan_runtime_->get_ti_device()->get_caps();
}

VulkanProgramImpl::~VulkanProgramImpl() {
vulkan_runtime_.reset();
embedded_device_.reset();
Expand Down
3 changes: 1 addition & 2 deletions taichi/runtime/program_impls/vulkan/vulkan_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class VulkanDeviceCreator;
class VulkanProgramImpl : public ProgramImpl {
public:
explicit VulkanProgramImpl(CompileConfig &config);
FunctionType compile(const CompileConfig &compile_config,
Kernel *kernel) override;

std::size_t get_snode_num_dynamically_allocated(
SNode *snode,
Expand Down Expand Up @@ -101,6 +99,7 @@ class VulkanProgramImpl : public ProgramImpl {
protected:
std::unique_ptr<KernelCompiler> make_kernel_compiler() override;
std::unique_ptr<KernelLauncher> make_kernel_launcher() override;
DeviceCapabilityConfig get_device_caps() override;

private:
std::unique_ptr<vulkan::VulkanDeviceCreator> embedded_device_{nullptr};
Expand Down

0 comments on commit 9857410

Please sign in to comment.