-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* push example with alexMI * [pre-commit.ci] auto fixes from pre-commit.com hooks * Update Dockerfile * Update classify_alexmi_with_quantum_pipeline.py --------- Co-authored-by: Gregoire Cattan <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
23a2a6a
commit 35187e1
Showing
1 changed file
with
122 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
""" | ||
==================================================================== | ||
Classification of MI datasets from MOABB using MDM | ||
and quantum-enhanced MDM | ||
==================================================================== | ||
This example demonstrates how to use quantum pipeline on a MI dataset. | ||
pip install moabb==0.5.0 | ||
""" | ||
# Author: Gregoire Cattan | ||
# Modified from ERP/classify_P300_bi_quantum_mdm.py | ||
# License: BSD (3-clause) | ||
|
||
from matplotlib import pyplot as plt | ||
import warnings | ||
import seaborn as sns | ||
from moabb import set_log_level | ||
from moabb.datasets import AlexMI | ||
from moabb.evaluations import WithinSessionEvaluation | ||
from moabb.paradigms import MotorImagery | ||
|
||
# inject convex distance and mean to pyriemann (if not done already) | ||
from pyriemann_qiskit.utils import distance, mean # noqa | ||
from pyriemann_qiskit.pipelines import ( | ||
QuantumMDMWithRiemannianPipeline, | ||
) | ||
|
||
from sklearn.pipeline import make_pipeline | ||
from pyriemann.estimation import ERPCovariances | ||
from pyriemann.classification import MDM | ||
|
||
print(__doc__) | ||
|
||
############################################################################## | ||
# getting rid of the warnings about the future | ||
warnings.simplefilter(action="ignore", category=FutureWarning) | ||
warnings.simplefilter(action="ignore", category=RuntimeWarning) | ||
|
||
warnings.filterwarnings("ignore") | ||
|
||
set_log_level("info") | ||
|
||
############################################################################## | ||
# Initialization | ||
# ---------------- | ||
# | ||
# 1) Create paradigm | ||
# 2) Load datasets | ||
|
||
paradigm = MotorImagery(events=["feet", "right_hand"], n_classes=2) | ||
|
||
datasets = [AlexMI()] | ||
|
||
# reduce the number of subjects | ||
n_subjects = 2 | ||
title = "Datasets: " | ||
for dataset in datasets: | ||
title = title + " " + dataset.code | ||
dataset.subject_list = dataset.subject_list[0:n_subjects] | ||
|
||
############################################################################## | ||
# Create Pipelines | ||
# ---------------- | ||
# | ||
# Pipelines must be a dict of sklearn pipeline transformer. | ||
|
||
pipelines = {} | ||
|
||
# Will run QAOA under the hood | ||
pipelines["mean=logeuclid/distance=convex"] = QuantumMDMWithRiemannianPipeline( | ||
convex_metric="distance", quantum=True | ||
) | ||
|
||
# Classical baseline for evaluation | ||
pipelines["R-MDM"] = make_pipeline(ERPCovariances(estimator="lwf"), MDM()) | ||
|
||
############################################################################## | ||
# Run evaluation | ||
# ---------------- | ||
# | ||
# Compare the pipelines using a within session evaluation. | ||
|
||
evaluation = WithinSessionEvaluation( | ||
paradigm=paradigm, | ||
datasets=datasets, | ||
overwrite=True, | ||
) | ||
|
||
results = evaluation.process(pipelines) | ||
|
||
print("Averaging the session performance:") | ||
print(results.groupby("pipeline").mean("score")[["score", "time"]]) | ||
|
||
|
||
# ############################################################################## | ||
# # Plot Results | ||
# # ---------------- | ||
# # | ||
# # Here we plot the results to compare two pipelines | ||
|
||
fig, ax = plt.subplots(facecolor="white", figsize=[8, 4]) | ||
|
||
sns.stripplot( | ||
data=results, | ||
y="score", | ||
x="pipeline", | ||
ax=ax, | ||
jitter=True, | ||
alpha=0.5, | ||
zorder=1, | ||
palette="Set1", | ||
) | ||
sns.pointplot(data=results, y="score", x="pipeline", ax=ax, palette="Set1").set( | ||
title=title | ||
) | ||
|
||
ax.set_ylabel("ROC AUC") | ||
ax.set_ylim(0.3, 1) | ||
|
||
plt.show() |