Skip to content

Commit

Permalink
Example with Motor Imagery (#177)
Browse files Browse the repository at this point in the history
* push example with alexMI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update Dockerfile

* Update classify_alexmi_with_quantum_pipeline.py

---------

Co-authored-by: Gregoire Cattan <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 4, 2023
1 parent 23a2a6a commit 35187e1
Showing 1 changed file with 122 additions and 0 deletions.
122 changes: 122 additions & 0 deletions examples/MI/classify_alexmi_with_quantum_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
====================================================================
Classification of MI datasets from MOABB using MDM
and quantum-enhanced MDM
====================================================================
This example demonstrates how to use quantum pipeline on a MI dataset.
pip install moabb==0.5.0
"""
# Author: Gregoire Cattan
# Modified from ERP/classify_P300_bi_quantum_mdm.py
# License: BSD (3-clause)

from matplotlib import pyplot as plt
import warnings
import seaborn as sns
from moabb import set_log_level
from moabb.datasets import AlexMI
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import MotorImagery

# inject convex distance and mean to pyriemann (if not done already)
from pyriemann_qiskit.utils import distance, mean # noqa
from pyriemann_qiskit.pipelines import (
QuantumMDMWithRiemannianPipeline,
)

from sklearn.pipeline import make_pipeline
from pyriemann.estimation import ERPCovariances
from pyriemann.classification import MDM

print(__doc__)

##############################################################################
# getting rid of the warnings about the future
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

warnings.filterwarnings("ignore")

set_log_level("info")

##############################################################################
# Initialization
# ----------------
#
# 1) Create paradigm
# 2) Load datasets

paradigm = MotorImagery(events=["feet", "right_hand"], n_classes=2)

datasets = [AlexMI()]

# reduce the number of subjects
n_subjects = 2
title = "Datasets: "
for dataset in datasets:
title = title + " " + dataset.code
dataset.subject_list = dataset.subject_list[0:n_subjects]

##############################################################################
# Create Pipelines
# ----------------
#
# Pipelines must be a dict of sklearn pipeline transformer.

pipelines = {}

# Will run QAOA under the hood
pipelines["mean=logeuclid/distance=convex"] = QuantumMDMWithRiemannianPipeline(
convex_metric="distance", quantum=True
)

# Classical baseline for evaluation
pipelines["R-MDM"] = make_pipeline(ERPCovariances(estimator="lwf"), MDM())

##############################################################################
# Run evaluation
# ----------------
#
# Compare the pipelines using a within session evaluation.

evaluation = WithinSessionEvaluation(
paradigm=paradigm,
datasets=datasets,
overwrite=True,
)

results = evaluation.process(pipelines)

print("Averaging the session performance:")
print(results.groupby("pipeline").mean("score")[["score", "time"]])


# ##############################################################################
# # Plot Results
# # ----------------
# #
# # Here we plot the results to compare two pipelines

fig, ax = plt.subplots(facecolor="white", figsize=[8, 4])

sns.stripplot(
data=results,
y="score",
x="pipeline",
ax=ax,
jitter=True,
alpha=0.5,
zorder=1,
palette="Set1",
)
sns.pointplot(data=results, y="score", x="pipeline", ax=ax, palette="Set1").set(
title=title
)

ax.set_ylabel("ROC AUC")
ax.set_ylim(0.3, 1)

plt.show()

0 comments on commit 35187e1

Please sign in to comment.