Skip to content

Commit 7da199c

Browse files
chaeyeunparkgithub-actions[bot]AmintorDuskomlxd
authored
Change AVX2/512 kernel infrastructure for additional gate/generator operations (#404)
* Adding tests * Auto update version * Fix * Update comment * Small fiX; format * Add tests for AVX2/512 gate helpers * Fix for tidy * Remove some internal classes from doc * Auto update version * Apply suggestions from code review Co-authored-by: Amintor Dusko <[email protected]> * Change set to overloadded functions * Format * Fix test name * Auto update version * Change function mae set to setValue * New black * Apply suggestions from code review Co-authored-by: Lee James O'Riordan <[email protected]> * Apply suggestions from code review Co-authored-by: Lee James O'Riordan <[email protected]> * Auto update version * Enable Dispatcher in C++ only build * Auto update version * Update changelog --------- Co-authored-by: Dev version update bot <github-actions[bot]@users.noreply.github.com> Co-authored-by: Amintor Dusko <[email protected]> Co-authored-by: Lee James O'Riordan <[email protected]>
1 parent f02becc commit 7da199c

34 files changed

+1998
-327
lines changed

.github/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
* Remove explicit Numpy and Scipy requirements.
1313
[(#412)](https://github.com/PennyLaneAI/pennylane-lightning/pull/412)
1414

15+
* Update AVX2/512 kernel infrastructure for additional gate/generator operations.
16+
[(#404)](https://github.com/PennyLaneAI/pennylane-lightning/pull/404)
17+
1518
### Documentation
1619

1720
### Bug fixes

pennylane_lightning/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
Version number (major.minor.patch[-label])
1717
"""
1818

19-
__version__ = "0.29.0-dev5"
19+
__version__ = "0.29.0-dev6"

pennylane_lightning/src/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ set(CMAKE_CXX_STANDARD 20) # At least C++20 is required
66

77
option(ENABLE_WARNINGS "Enable warnings" ON)
88
option(ENABLE_OPENMP "Enable OpenMP" ON)
9+
option(ENABLE_GATE_DISPATCHER "Enable gate kernel dispatching on AVX/AVX2/AVX512" ON)
910

1011
if(ENABLE_CLANG_TIDY)
1112
if(NOT DEFINED CLANG_TIDY_BINARY)

pennylane_lightning/src/gates/cpu_kernels/GateImplementationsAVXCommon.hpp

+18-44
Original file line numberDiff line numberDiff line change
@@ -16,33 +16,7 @@
1616
* Defines kernel functions for all AVX
1717
*/
1818
#pragma once
19-
20-
// General implementations
21-
#include "Macros.hpp"
22-
23-
#ifdef PL_USE_AVX2
24-
#include "avx_common/AVX2Concept.hpp"
25-
#endif
26-
#ifdef PL_USE_AVX512F
27-
#include "avx_common/AVX512Concept.hpp"
28-
#endif
29-
#include "avx_common/ApplyCNOT.hpp"
30-
#include "avx_common/ApplyCZ.hpp"
31-
#include "avx_common/ApplyHadamard.hpp"
32-
#include "avx_common/ApplyIsingXX.hpp"
33-
#include "avx_common/ApplyIsingYY.hpp"
34-
#include "avx_common/ApplyIsingZZ.hpp"
35-
#include "avx_common/ApplyPauliX.hpp"
36-
#include "avx_common/ApplyPauliY.hpp"
37-
#include "avx_common/ApplyPauliZ.hpp"
38-
#include "avx_common/ApplyPhaseShift.hpp"
39-
#include "avx_common/ApplyRX.hpp"
40-
#include "avx_common/ApplyRY.hpp"
41-
#include "avx_common/ApplyRZ.hpp"
42-
#include "avx_common/ApplyS.hpp"
43-
#include "avx_common/ApplySWAP.hpp"
44-
#include "avx_common/ApplySingleQubitOp.hpp"
45-
#include "avx_common/ApplyT.hpp"
19+
#include "avx_common/AVXGateKernels.hpp"
4620
#include "avx_common/SingleQubitGateHelper.hpp"
4721
#include "avx_common/TwoQubitGateHelper.hpp"
4822

@@ -91,7 +65,7 @@ class GateImplementationsAVXCommon
9165
std::is_same_v<PrecisionT, double>,
9266
"Only float and double are supported.");
9367

94-
assert(wires.size() == 1);
68+
PL_ASSERT(wires.size() == 1);
9569
auto helper =
9670
AVXCommon::SingleQubitGateWithoutParamHelper<ApplyPauliXAVX>(
9771
&GateImplementationsLM::applyPauliX);
@@ -110,7 +84,7 @@ class GateImplementationsAVXCommon
11084
std::is_same_v<PrecisionT, double>,
11185
"Only float and double are supported.");
11286

113-
assert(wires.size() == 1);
87+
PL_ASSERT(wires.size() == 1);
11488
auto helper =
11589
AVXCommon::SingleQubitGateWithoutParamHelper<ApplyPauliYAVX>(
11690
&GateImplementationsLM::applyPauliY);
@@ -129,7 +103,7 @@ class GateImplementationsAVXCommon
129103
std::is_same_v<PrecisionT, double>,
130104
"Only float and double are supported.");
131105

132-
assert(wires.size() == 1);
106+
PL_ASSERT(wires.size() == 1);
133107
auto helper =
134108
AVXCommon::SingleQubitGateWithoutParamHelper<ApplyPauliZAVX>(
135109
&GateImplementationsLM::applyPauliZ);
@@ -161,7 +135,7 @@ class GateImplementationsAVXCommon
161135
std::is_same_v<PrecisionT, double>,
162136
"Only float and double are supported.");
163137

164-
assert(wires.size() == 1);
138+
PL_ASSERT(wires.size() == 1);
165139
auto helper = AVXCommon::SingleQubitGateWithoutParamHelper<ApplyTAVX>(
166140
&GateImplementationsLM::applyT);
167141
helper(arr, num_qubits, wires, inverse);
@@ -179,7 +153,7 @@ class GateImplementationsAVXCommon
179153
std::is_same_v<PrecisionT, double>,
180154
"Only float and double are supported.");
181155

182-
assert(wires.size() == 1);
156+
PL_ASSERT(wires.size() == 1);
183157
auto helper =
184158
AVXCommon::SingleQubitGateWithParamHelper<ApplyPhaseShiftAVX,
185159
ParamT>(
@@ -198,7 +172,7 @@ class GateImplementationsAVXCommon
198172
static_assert(std::is_same_v<PrecisionT, float> ||
199173
std::is_same_v<PrecisionT, double>,
200174
"Only float and double are supported.");
201-
assert(wires.size() == 1);
175+
PL_ASSERT(wires.size() == 1);
202176
auto helper =
203177
AVXCommon::SingleQubitGateWithoutParamHelper<ApplyHadamardAVX>(
204178
&GateImplementationsLM::applyHadamard);
@@ -215,7 +189,7 @@ class GateImplementationsAVXCommon
215189
static_assert(std::is_same_v<PrecisionT, float> ||
216190
std::is_same_v<PrecisionT, double>,
217191
"Only float and double are supported.");
218-
assert(wires.size() == 1);
192+
PL_ASSERT(wires.size() == 1);
219193
auto helper =
220194
AVXCommon::SingleQubitGateWithParamHelper<ApplyRXAVX, ParamT>(
221195
&GateImplementationsLM::applyRX);
@@ -232,7 +206,7 @@ class GateImplementationsAVXCommon
232206
static_assert(std::is_same_v<PrecisionT, float> ||
233207
std::is_same_v<PrecisionT, double>,
234208
"Only float and double are supported.");
235-
assert(wires.size() == 1);
209+
PL_ASSERT(wires.size() == 1);
236210
auto helper =
237211
AVXCommon::SingleQubitGateWithParamHelper<ApplyRYAVX, ParamT>(
238212
&GateImplementationsLM::applyRY);
@@ -249,7 +223,7 @@ class GateImplementationsAVXCommon
249223
static_assert(std::is_same_v<PrecisionT, float> ||
250224
std::is_same_v<PrecisionT, double>,
251225
"Only float and double are supported.");
252-
assert(wires.size() == 1);
226+
PL_ASSERT(wires.size() == 1);
253227
auto helper =
254228
AVXCommon::SingleQubitGateWithParamHelper<ApplyRZAVX, ParamT>(
255229
&GateImplementationsLM::applyRZ);
@@ -260,7 +234,7 @@ class GateImplementationsAVXCommon
260234
static void applyRot(std::complex<PrecisionT> *arr, const size_t num_qubits,
261235
const std::vector<size_t> &wires, bool inverse,
262236
ParamT phi, ParamT theta, ParamT omega) {
263-
assert(wires.size() == 1);
237+
PL_ASSERT(wires.size() == 1);
264238

265239
const auto rotMat =
266240
(inverse) ? Gates::getRot<PrecisionT>(-omega, -theta, -phi)
@@ -282,7 +256,7 @@ class GateImplementationsAVXCommon
282256
std::is_same_v<PrecisionT, double>,
283257
"Only float and double are supported.");
284258

285-
assert(wires.size() == 2);
259+
PL_ASSERT(wires.size() == 2);
286260

287261
const AVXCommon::TwoQubitGateWithoutParamHelper<ApplyCZAVX> gate_helper(
288262
&GateImplementationsLM::applyCZ<PrecisionT>);
@@ -302,7 +276,7 @@ class GateImplementationsAVXCommon
302276
std::is_same_v<PrecisionT, double>,
303277
"Only float and double are supported.");
304278

305-
assert(wires.size() == 2);
279+
PL_ASSERT(wires.size() == 2);
306280

307281
const AVXCommon::TwoQubitGateWithoutParamHelper<ApplySWAPAVX>
308282
gate_helper(&GateImplementationsLM::applySWAP<PrecisionT>);
@@ -314,7 +288,7 @@ class GateImplementationsAVXCommon
314288
static void
315289
applyCNOT(std::complex<PrecisionT> *arr, const size_t num_qubits,
316290
const std::vector<size_t> &wires, [[maybe_unused]] bool inverse) {
317-
assert(wires.size() == 2);
291+
PL_ASSERT(wires.size() == 2);
318292

319293
using ApplyCNOTAVX =
320294
AVXCommon::ApplyCNOT<PrecisionT,
@@ -327,7 +301,7 @@ class GateImplementationsAVXCommon
327301
std::is_same_v<PrecisionT, double>,
328302
"Only float and double are supported.");
329303

330-
assert(wires.size() == 2);
304+
PL_ASSERT(wires.size() == 2);
331305

332306
const AVXCommon::TwoQubitGateWithoutParamHelper<ApplyCNOTAVX>
333307
gate_helper(&GateImplementationsLM::applyCNOT<PrecisionT>);
@@ -340,7 +314,7 @@ class GateImplementationsAVXCommon
340314
const size_t num_qubits,
341315
const std::vector<size_t> &wires,
342316
[[maybe_unused]] bool inverse, ParamT angle) {
343-
assert(wires.size() == 2);
317+
PL_ASSERT(wires.size() == 2);
344318

345319
using ApplyIsingXXAVX =
346320
AVXCommon::ApplyIsingXX<PrecisionT,
@@ -362,7 +336,7 @@ class GateImplementationsAVXCommon
362336
const size_t num_qubits,
363337
const std::vector<size_t> &wires,
364338
[[maybe_unused]] bool inverse, ParamT angle) {
365-
assert(wires.size() == 2);
339+
PL_ASSERT(wires.size() == 2);
366340

367341
using ApplyIsingYYAVX =
368342
AVXCommon::ApplyIsingYY<PrecisionT,
@@ -392,7 +366,7 @@ class GateImplementationsAVXCommon
392366
std::is_same_v<PrecisionT, double>,
393367
"Only float and double are supported.");
394368

395-
assert(wires.size() == 2);
369+
PL_ASSERT(wires.size() == 2);
396370

397371
const AVXCommon::TwoQubitGateWithParamHelper<ApplyIsingZZAVX, ParamT>
398372
gate_helper(

pennylane_lightning/src/gates/cpu_kernels/avx_common/AVX2Concept.hpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,22 @@
1818
#pragma once
1919
#include "AVXUtil.hpp"
2020
#include "BitUtil.hpp"
21-
#include "Macros.hpp"
2221
#include "Util.hpp"
2322

2423
#include <immintrin.h>
2524

2625
#include <type_traits>
2726

2827
namespace Pennylane::Gates::AVXCommon {
28+
///@cond DEV
2929
namespace Internal {
3030
template <typename T> struct AVX2Intrinsic {
3131
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>);
3232
};
3333
template <> struct AVX2Intrinsic<float> { using Type = __m256; };
3434
template <> struct AVX2Intrinsic<double> { using Type = __m256d; };
3535
} // namespace Internal
36+
///@endcond
3637

3738
template <typename T> struct AVX2Concept {
3839
using PrecisionT = T;
@@ -110,6 +111,4 @@ template <typename T> struct AVX2Concept {
110111
}
111112
}
112113
};
113-
template <> struct AVXConcept<float, 8> { using Type = AVX2Concept<float>; };
114-
template <> struct AVXConcept<double, 4> { using Type = AVX2Concept<double>; };
115114
} // namespace Pennylane::Gates::AVXCommon

pennylane_lightning/src/gates/cpu_kernels/avx_common/AVX512Concept.hpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,15 @@
2626
#include <type_traits>
2727

2828
namespace Pennylane::Gates::AVXCommon {
29+
///@cond DEV
2930
namespace Internal {
3031
template <typename T> struct AVX512Intrinsic {
3132
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>);
3233
};
3334
template <> struct AVX512Intrinsic<float> { using Type = __m512; };
3435
template <> struct AVX512Intrinsic<double> { using Type = __m512d; };
3536
} // namespace Internal
37+
///@endcond
3638

3739
template <typename T> struct AVX512Concept {
3840
using PrecisionT = T;
@@ -111,8 +113,4 @@ template <typename T> struct AVX512Concept {
111113
}
112114
};
113115

114-
template <> struct AVXConcept<float, 16> { using Type = AVX512Concept<float>; };
115-
template <> struct AVXConcept<double, 8> {
116-
using Type = AVX512Concept<double>;
117-
};
118116
} // namespace Pennylane::Gates::AVXCommon
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright 2023 Xanadu Quantum Technologies Inc.
2+
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
/**
15+
* @file
16+
* Defines AVXConcept types
17+
*/
18+
#pragma once
19+
20+
#include "Macros.hpp"
21+
22+
#ifdef PL_USE_AVX2
23+
#include "AVX2Concept.hpp"
24+
#endif
25+
26+
#ifdef PL_USE_AVX512F
27+
#include "AVX512Concept.hpp"
28+
#endif
29+
30+
namespace Pennylane::Gates::AVXCommon {
31+
32+
template <class PrecisionT, size_t packed_size> struct AVXConcept;
33+
34+
#ifdef PL_USE_AVX2
35+
template <> struct AVXConcept<float, 8> { using Type = AVX2Concept<float>; };
36+
template <> struct AVXConcept<double, 4> { using Type = AVX2Concept<double>; };
37+
#endif
38+
39+
#ifdef PL_USE_AVX512F
40+
template <> struct AVXConcept<float, 16> { using Type = AVX512Concept<float>; };
41+
template <> struct AVXConcept<double, 8> {
42+
using Type = AVX512Concept<double>;
43+
};
44+
#endif
45+
46+
template <class PrecisionT, size_t packed_size>
47+
using AVXConceptType = typename AVXConcept<PrecisionT, packed_size>::Type;
48+
49+
} // namespace Pennylane::Gates::AVXCommon
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Copyright 2023 Xanadu Quantum Technologies Inc.
2+
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
/**
15+
* @file
16+
* Include all AVX gate implementations
17+
*/
18+
#pragma once
19+
#include "ApplyCNOT.hpp"
20+
#include "ApplyCZ.hpp"
21+
#include "ApplyHadamard.hpp"
22+
#include "ApplyIsingXX.hpp"
23+
#include "ApplyIsingYY.hpp"
24+
#include "ApplyIsingZZ.hpp"
25+
#include "ApplyPauliX.hpp"
26+
#include "ApplyPauliY.hpp"
27+
#include "ApplyPauliZ.hpp"
28+
#include "ApplyPhaseShift.hpp"
29+
#include "ApplyRX.hpp"
30+
#include "ApplyRY.hpp"
31+
#include "ApplyRZ.hpp"
32+
#include "ApplyS.hpp"
33+
#include "ApplySWAP.hpp"
34+
#include "ApplySingleQubitOp.hpp"
35+
#include "ApplyT.hpp"

0 commit comments

Comments
 (0)