Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/NCH and MDM #263

Merged
merged 36 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
f214ac6
add distances-visualization module and example
gcattan Mar 23, 2024
58a4c18
add module manifold
gcattan Mar 23, 2024
578081f
complete api.rst
gcattan Mar 23, 2024
10687b7
complete doc
gcattan Mar 23, 2024
785c61f
add documentation for plot_scatter
gcattan Mar 23, 2024
bac51f3
run example on Ci
gcattan Mar 23, 2024
7a136ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2024
dc6305b
remove jupyter notebook
gcattan Mar 23, 2024
ac8ad38
missing in previous ocmmit
gcattan Mar 23, 2024
ab366cf
- improve doc
gcattan Mar 24, 2024
f42c61b
lint issues
gcattan Mar 24, 2024
ca1615a
remove distances.py module
gcattan Mar 26, 2024
881a631
Update pyriemann_qiskit/utils/math.py
gcattan Mar 26, 2024
3b9f984
Update pyriemann_qiskit/visualization/manifold.py
gcattan Mar 26, 2024
ec6f24b
Update pyriemann_qiskit/visualization/manifold.py
gcattan Mar 26, 2024
4cdf9d5
replace spds -> X
gcattan Mar 26, 2024
369ea57
use dev version for pyRiemann when running examples
gcattan Mar 26, 2024
5c5092d
try --ignore-requires-python
gcattan Mar 26, 2024
5b785d3
is it the right option?
gcattan Mar 26, 2024
8c06067
try installing moabb from dev
gcattan Mar 26, 2024
dc640dc
remove distances_visualization.py from Ci/Cd
gcattan Mar 26, 2024
bb2a14d
go back to pyRiemann 0.5. We will just wait for release.
gcattan Mar 26, 2024
f37e54c
fix light_benchmark.yml
gcattan Mar 26, 2024
48a59ba
start in ou classifier
gcattan Mar 28, 2024
85b7720
move classifier to pipelines module
gcattan Apr 4, 2024
4e76b43
Merge branch 'main' into gc/nchandmdm
gcattan Apr 4, 2024
9083de5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2024
158f8af
revert change to examples
gcattan Apr 4, 2024
befa93f
revert unnecessary changes
gcattan Apr 4, 2024
392dfac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2024
f512baa
fix lint
gcattan Apr 4, 2024
03cf653
Update pyriemann_qiskit/pipelines.py
gcattan Apr 5, 2024
b78d432
Update pyriemann_qiskit/pipelines.py
gcattan Apr 5, 2024
4e4e004
Update api.rst
gcattan Apr 5, 2024
3dda9ff
Update setup.py
gcattan Apr 5, 2024
cdd9c76
Update Dockerfile
gcattan Apr 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ RUN mkdir /home/mne_data

## Workaround for firestore
RUN pip install protobuf==4.25.3
RUN pip install google_cloud_firestore==2.15.0
RUN pip install google_cloud_firestore==2.16.0
### Missing __init__ file in protobuf
RUN touch /usr/local/lib/python3.9/site-packages/protobuf-4.25.3-py3.9.egg/google/__init__.py
## google.cloud.location is never used in these files, and is missing in path.
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.15.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.15.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.15.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.15.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.15.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.15.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/async_client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.16.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/client.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.16.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/base.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.16.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.16.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.16.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/transports/rest.py'
RUN sed -i 's/from google.cloud.location import locations_pb2//g' '/usr/local/lib/python3.9/site-packages/google_cloud_firestore-2.16.0-py3.9.egg/google/cloud/firestore_v1/services/firestore/async_client.py'

ENTRYPOINT [ "python", "/examples/ERP/classify_P300_bi.py" ]
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Pipelines
QuantumClassifierWithDefaultRiemannianPipeline
QuantumMDMWithRiemannianPipeline
QuantumMDMVotingClassifier
FeaturesUnionClassifier


Ensemble
Expand Down
55 changes: 53 additions & 2 deletions pyriemann_qiskit/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline, FeatureUnion
from sklearn.ensemble import VotingClassifier
from qiskit_optimization.algorithms import CobylaOptimizer
from pyriemann.estimation import XdawnCovariances, ERPCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.preprocessing import Whitening
from pyriemann.classification import MDM
from pyriemann_qiskit.utils.utils import is_qfunction
from pyriemann_qiskit.utils.filtering import NoDimRed
from pyriemann_qiskit.utils.hyper_params_factory import (
Expand All @@ -16,7 +18,12 @@
gen_two_local,
get_spsa,
)
from pyriemann_qiskit.classification import QuanticVQC, QuanticSVM, QuanticMDM
from pyriemann_qiskit.classification import (
QuanticNCH,
QuanticVQC,
QuanticSVM,
QuanticMDM,
)


class BasePipeline(BaseEstimator, ClassifierMixin, TransformerMixin):
Expand Down Expand Up @@ -497,3 +504,47 @@ def _create_pipe(self):
voting="soft",
)
)


class FeaturesUnionClassifier(BasePipeline):

"""An alias for FeatureUnion + Classifier

Aggregate features generated by different transformers, and
use a classifier (e.g. LDA) in top of it.

Parameters
----------
transformers : List[TransformerMixin], default=[QuanticNCH, MDM]
A list of sklearn transformers.
classifier : ClassifierMixin, default=LDA()
A classifier

Attributes
----------
classes_ : list
list of classes.

Notes
-----
.. versionadded:: 0.2.0

"""

def __init__(
self,
transformers=[
QuanticNCH(quantum=True, subsampling="random", n_jobs=-1),
MDM(metric="logeuclid"),
],
classifier=LDA(),
):
self.transformers = transformers
self.classifier = classifier
BasePipeline.__init__(self, "FeatureUnionClassifier")

def _create_pipe(self):
return make_pipeline(
FeatureUnion([(type(t).__name__, t) for t in self.transformers]),
self.classifier,
)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
'qiskit-aer==0.12.2',
'cvxpy==1.4.2',
'scipy==1.11.4',
'docplex>=2.21.207',
'docplex==2.25.236',
'firebase_admin==6.4.0',
'scikit-learn==1.3.2',
'tqdm',
Expand Down
Loading