From 0f697461f7d7c68c5d59443a3af9ed4ccc27e6aa Mon Sep 17 00:00:00 2001 From: Zhaoyilunnn Date: Thu, 18 Apr 2024 03:47:01 +0000 Subject: [PATCH] fix: close #160 --- quafu/algorithms/interface_provider.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/quafu/algorithms/interface_provider.py b/quafu/algorithms/interface_provider.py index 04989759..1765b556 100644 --- a/quafu/algorithms/interface_provider.py +++ b/quafu/algorithms/interface_provider.py @@ -12,14 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .interface.torch import TorchTransformer - -PROVIDERS = {"torch": TorchTransformer} - class InterfaceProvider: + _init = False + _providers = {} + @classmethod def get(cls, name: str): - if name not in PROVIDERS: + if not cls._init: + from .interface.torch import TorchTransformer + + cls._providers["torch"] = TorchTransformer + + cls._init = True + + if name not in cls._providers: raise NotImplementedError(f"Unsupported interface: {name}") - return PROVIDERS[name] + return cls._providers[name]