diff --git a/Dockerfile b/Dockerfile index 56fe626b..412ef204 100644 --- a/Dockerfile +++ b/Dockerfile @@ -22,10 +22,10 @@ RUN mkdir /root/mne_data RUN mkdir /home/mne_data ## Workaround for firestore -RUN pip install protobuf==4.24.0rc2 +RUN pip install protobuf==4.24.2 RUN pip install google_cloud_firestore==2.11.1 ### Missing __init__ file in protobuf -RUN touch /usr/local/lib/python3.8/site-packages/protobuf-4.24.0rc2-py3.8.egg/google/__init__.py +RUN touch /usr/local/lib/python3.8/site-packages/protobuf-4.24.2-py3.8-linux-x86_64.egg/google/__init__.py ## google.cloud.location is never used in these files, and is missing in path. RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/client.py' RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.8/site-packages/google_cloud_firestore-2.11.1-py3.8.egg/google/cloud/firestore_v1/services/firestore/transports/base.py' diff --git a/examples/MI/classify_alexmi_with_quantum_pipeline.py b/examples/MI/classify_alexmi_with_quantum_pipeline.py new file mode 100644 index 00000000..fdc3a6ba --- /dev/null +++ b/examples/MI/classify_alexmi_with_quantum_pipeline.py @@ -0,0 +1,122 @@ +""" +==================================================================== +Classification of MI datasets from MOABB using MDM +and quantum-enhanced MDM +==================================================================== + +This example demonstrates how to use quantum pipeline on a MI dataset. + +pip install moabb==0.5.0 + +""" +# Author: Gregoire Cattan +# Modified from ERP/classify_P300_bi_quantum_mdm.py +# License: BSD (3-clause) + +from matplotlib import pyplot as plt +import warnings +import seaborn as sns +from moabb import set_log_level +from moabb.datasets import AlexMI +from moabb.evaluations import WithinSessionEvaluation +from moabb.paradigms import MotorImagery + +# inject convex distance and mean to pyriemann (if not done already) +from pyriemann_qiskit.utils import distance, mean # noqa +from pyriemann_qiskit.pipelines import ( + QuantumMDMWithRiemannianPipeline, +) + +from sklearn.pipeline import make_pipeline +from pyriemann.estimation import ERPCovariances +from pyriemann.classification import MDM + +print(__doc__) + +############################################################################## +# getting rid of the warnings about the future +warnings.simplefilter(action="ignore", category=FutureWarning) +warnings.simplefilter(action="ignore", category=RuntimeWarning) + +warnings.filterwarnings("ignore") + +set_log_level("info") + +############################################################################## +# Initialization +# ---------------- +# +# 1) Create paradigm +# 2) Load datasets + +paradigm = MotorImagery(events=["feet", "right_hand"], n_classes=2) + +datasets = [AlexMI()] + +# reduce the number of subjects +n_subjects = 2 +title = "Datasets: " +for dataset in datasets: + title = title + " " + dataset.code + dataset.subject_list = dataset.subject_list[0:n_subjects] + +############################################################################## +# Create Pipelines +# ---------------- +# +# Pipelines must be a dict of sklearn pipeline transformer. + +pipelines = {} + +# Will run QAOA under the hood +pipelines["mean=logeuclid/distance=convex"] = QuantumMDMWithRiemannianPipeline( + convex_metric="distance", quantum=True +) + +# Classical baseline for evaluation +pipelines["R-MDM"] = make_pipeline(ERPCovariances(estimator="lwf"), MDM()) + +############################################################################## +# Run evaluation +# ---------------- +# +# Compare the pipelines using a within session evaluation. + +evaluation = WithinSessionEvaluation( + paradigm=paradigm, + datasets=datasets, + overwrite=True, +) + +results = evaluation.process(pipelines) + +print("Averaging the session performance:") +print(results.groupby("pipeline").mean("score")[["score", "time"]]) + + +# ############################################################################## +# # Plot Results +# # ---------------- +# # +# # Here we plot the results to compare two pipelines + +fig, ax = plt.subplots(facecolor="white", figsize=[8, 4]) + +sns.stripplot( + data=results, + y="score", + x="pipeline", + ax=ax, + jitter=True, + alpha=0.5, + zorder=1, + palette="Set1", +) +sns.pointplot(data=results, y="score", x="pipeline", ax=ax, palette="Set1").set( + title=title +) + +ax.set_ylabel("ROC AUC") +ax.set_ylim(0.3, 1) + +plt.show()