diff --git a/py/trtorch/csrc/tensorrt_backend.cpp b/py/trtorch/csrc/tensorrt_backend.cpp
index 503a495956..adcda8a816 100644
--- a/py/trtorch/csrc/tensorrt_backend.cpp
+++ b/py/trtorch/csrc/tensorrt_backend.cpp
@@ -24,40 +24,29 @@ c10::IValue preprocess(const torch::jit::Module& mod, const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec) {
   auto spec = c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
 
-  core::lowering::LowerInfo lower_info;
-  for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
-    const auto& method_name = it->key();
-    auto method = mod.get_method(method_name);
-    auto graph = method.graph();
-    core::lowering::LowerGraph(graph, lower_info);
-  }
 
   auto handles = c10::impl::GenericDict(
       c10::StringType::get(), c10::getCustomClassType<c10::intrusive_ptr<core::runtime::TRTEngine>>());
 
   for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
+    auto mod_ = mod.clone();
     const auto& method_name = it->key();
-    auto method = mod.get_method(method_name);
-    auto g = method.graph();
-
     auto raw_spec = it->value().toCustomClass<trtorch::pyapi::CompileSpec>();
     LOG_DEBUG(raw_spec->stringify());
     auto cfg = raw_spec->toInternalCompileSpec();
-    auto convert_cfg = std::move(cfg.convert_info);
-    auto graph_and_ivalues = torch::jit::LowerGraph(*g, mod._ivalue());
+    auto graph_and_ivals = Lower(mod_, method_name, cfg.lower_info);
 
-    g = graph_and_ivalues.first;
-    auto params = graph_and_ivalues.second;
+    auto g = graph_and_ivals.first;
+    auto params = graph_and_ivals.second;
     auto named_params = core::conversion::get_named_params(g->inputs(), params);
 
+    auto convert_cfg = std::move(cfg.convert_info);
     auto device_spec = convert_cfg.engine_settings.device;
     auto device = core::runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
     auto serialized_engine = core::conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
     auto engine_handle = c10::make_intrusive<core::runtime::TRTEngine>(it->key(), serialized_engine, device);
-    handles.insert(method.name(), at::IValue(engine_handle));
+    handles.insert(method_name, at::IValue(engine_handle));
   }
 
   return c10::impl::toGenericDict(handles);