From dff492d86d3d989347e3b5554a93f0213f9def9f Mon Sep 17 00:00:00 2001 From: Zhaoyilunnn Date: Tue, 10 Sep 2024 17:44:34 +0800 Subject: [PATCH] fix(algorithms): make qnn forward and backward consistent --- quafu/algorithms/gradients/param_shift.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/quafu/algorithms/gradients/param_shift.py b/quafu/algorithms/gradients/param_shift.py index 32fd56fb..220d1305 100644 --- a/quafu/algorithms/gradients/param_shift.py +++ b/quafu/algorithms/gradients/param_shift.py @@ -35,6 +35,8 @@ def __call__(self, obs: Hamiltonian, params: List[float]): estimator (Estimator): estimator to calculate expectation values params (List[float]): params to optimize """ + if self._est._backend != "sim": + return self.grad(obs, params) return self.new_grad(obs, params) def _gen_param_shift_vals(self, params): @@ -46,7 +48,6 @@ def _gen_param_shift_vals(self, params): minus_params = params - offsets * np.pi / 2 return plus_params.tolist() + minus_params.tolist() - # TODO: delete after 0.4.1 def grad(self, obs: Hamiltonian, params: List[float]): """grad.