From ed96baeed19b4e11b6cbc2dcc6776245ba5fab13 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Tue, 31 Jan 2023 07:55:22 +0000 Subject: [PATCH] check tensor numel in PyObject_CheckLongOrToLong --- paddle/fluid/pybind/op_function_common.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index 5cdd9a0fa0668..4f9d6b2649270 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -30,6 +30,7 @@ #include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/operators/ops_extra_info.h" +#include "paddle/fluid/pybind/eager.h" #include "paddle/fluid/pybind/imperative.h" #include "paddle/phi/common/complex.h" @@ -70,7 +71,8 @@ bool PyObject_CheckLongOrToLong(PyObject** obj) { if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) || PyObject_IsInstance(*obj, (PyObject*)g_vartype_pytype) || // NOLINT PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype) || // NOLINT - PyObject_IsInstance(*obj, (PyObject*)p_tensor_type)) { // NOLINT + (PyObject_IsInstance(*obj, (PyObject*)p_tensor_type) && // NOLINT + (((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT return true; }