-
-
Notifications
You must be signed in to change notification settings - Fork 25.6k
/
Copy path_voting.py
753 lines (613 loc) · 25.4 KB
/
_voting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
"""
Soft Voting/Majority Rule classifier and Voting regressor.
This module contains:
- A Soft Voting/Majority Rule classifier for classification estimators.
- A Voting regressor for regression estimators.
"""
# Authors: Sebastian Raschka <[email protected]>,
# Gilles Louppe <[email protected]>,
# Ramil Nugmanov <[email protected]>
# Mohamed Ali Jamaoui <[email protected]>
#
# License: BSD 3 clause
from abc import abstractmethod
from numbers import Integral
import numpy as np
from ..base import (
ClassifierMixin,
RegressorMixin,
TransformerMixin,
_fit_context,
clone,
)
from ..exceptions import NotFittedError
from ..preprocessing import LabelEncoder
from ..utils import Bunch
from ..utils._estimator_html_repr import _VisualBlock
from ..utils._param_validation import StrOptions
from ..utils.metadata_routing import (
MetadataRouter,
MethodMapping,
_raise_for_params,
_routing_enabled,
process_routing,
)
from ..utils.metaestimators import available_if
from ..utils.multiclass import type_of_target
from ..utils.parallel import Parallel, delayed
from ..utils.validation import (
_check_feature_names_in,
_deprecate_positional_args,
check_is_fitted,
column_or_1d,
)
from ._base import _BaseHeterogeneousEnsemble, _fit_single_estimator
class _BaseVoting(TransformerMixin, _BaseHeterogeneousEnsemble):
    """Common base for the voting meta-estimators.

    Fits clones of the user-supplied ``estimators`` (skipping any entry set
    to ``'drop'``) and provides helpers shared by the classifier and the
    regressor subclasses.

    Warning: This class should not be used directly. Use derived classes
    instead.
    """

    _parameter_constraints: dict = {
        "estimators": [list],
        "weights": ["array-like", None],
        "n_jobs": [None, Integral],
        "verbose": ["verbose"],
    }

    def _log_message(self, name, idx, total):
        """Return the progress message for one estimator, or None if quiet."""
        if self.verbose:
            return f"({idx} of {total}) Processing {name}"
        return None

    @property
    def _weights_not_none(self):
        """Weights of the estimators that are not ``'drop'``.

        Returns ``None`` when ``weights`` is ``None``; otherwise a list of the
        weights whose paired estimator is not ``'drop'``, in the order of
        ``estimators``.
        """
        if self.weights is None:
            return None
        kept = []
        for pair, weight in zip(self.estimators, self.weights):
            # ``pair`` is a (name, estimator) entry of ``self.estimators``.
            if pair[1] != "drop":
                kept.append(weight)
        return kept

    def _predict(self, X):
        """Stack the ``predict`` output of every fitted estimator.

        The per-estimator outputs are combined with ``np.asarray`` and
        transposed, so estimators end up along the last axis.
        """
        outputs = [estimator.predict(X) for estimator in self.estimators_]
        return np.asarray(outputs).T

    @abstractmethod
    def fit(self, X, y, **fit_params):
        """Fit clones of the non-dropped estimators (shared fit logic).

        Checks that ``weights`` matches ``estimators`` in length, routes the
        fit parameters, fits the clones in parallel, and fills
        ``estimators_`` and ``named_estimators_``.

        Raises
        ------
        ValueError
            If ``weights`` is given and its length differs from the number
            of ``estimators``.
        """
        names, clfs = self._validate_estimators()

        n_estimators = len(self.estimators)
        if self.weights is not None and len(self.weights) != n_estimators:
            raise ValueError(
                "Number of `estimators` and weights must be equal; got"
                f" {len(self.weights)} weights, {n_estimators} estimators"
            )

        if _routing_enabled():
            routed_params = process_routing(self, "fit", **fit_params)
        else:
            # Without metadata routing, only ``sample_weight`` is forwarded,
            # and the same value goes to every sub-estimator.
            routed_params = Bunch()
            for name in names:
                per_fit = {}
                if "sample_weight" in fit_params:
                    per_fit["sample_weight"] = fit_params["sample_weight"]
                routed_params[name] = Bunch(fit=per_fit)

        total = len(clfs)
        jobs = (
            delayed(_fit_single_estimator)(
                clone(clf),
                X,
                y,
                fit_params=routed_params[name]["fit"],
                message_clsname="Voting",
                message=self._log_message(name, position, total),
            )
            for position, (name, clf) in enumerate(zip(names, clfs), start=1)
            if clf != "drop"
        )
        self.estimators_ = Parallel(n_jobs=self.n_jobs)(jobs)

        self.named_estimators_ = Bunch()

        # Dropped entries keep the literal 'drop' as a placeholder so every
        # configured name is present; the others take the next fitted clone.
        fitted = iter(self.estimators_)
        for name, est in self.estimators:
            if est == "drop":
                current_est = est
            else:
                current_est = next(fitted)
            self.named_estimators_[name] = current_est

            if hasattr(current_est, "feature_names_in_"):
                self.feature_names_in_ = current_est.feature_names_in_

        return self

    def fit_transform(self, X, y=None, **fit_params):
        """Fit the ensemble, then return the output of each estimator for X.

        Parameters
        ----------
        X : {array-like, sparse matrix, dataframe} of shape \
                (n_samples, n_features)
            Input samples.

        y : ndarray of shape (n_samples,), default=None
            Target values (None for unsupervised transformations).

        **fit_params : dict
            Additional fit parameters.

        Returns
        -------
        X_new : ndarray array of shape (n_samples, n_features_new)
            Transformed array (what the subclass's ``transform`` returns).
        """
        return super().fit_transform(X, y, **fit_params)

    @property
    def n_features_in_(self):
        """Number of features seen during :term:`fit`."""
        # Convert NotFittedError into AttributeError so that ``hasattr``
        # returns False on an unfitted ensemble, consistent with other
        # estimators.
        try:
            check_is_fitted(self)
        except NotFittedError as nfe:
            message = (
                f"{self.__class__.__name__} object has no n_features_in_ attribute."
            )
            raise AttributeError(message) from nfe
        return self.estimators_[0].n_features_in_

    def _sk_visual_block_(self):
        """Describe the ensemble as parallel named estimators for HTML repr."""
        names, estimators = zip(*self.estimators)
        return _VisualBlock("parallel", estimators, names=names)

    def get_metadata_routing(self):
        """Get metadata routing of this object.

        Please check :ref:`User Guide <metadata_routing>` on how the routing
        mechanism works.

        .. versionadded:: 1.5

        Returns
        -------
        routing : MetadataRouter
            A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
            routing information.
        """
        router = MetadataRouter(owner=self.__class__.__name__)

        # `self.estimators` is a list of (name, est) tuples; each child is
        # registered so that metadata passed to `fit` reaches its `fit`.
        for name, estimator in self.estimators:
            mapping = MethodMapping().add(callee="fit", caller="fit")
            router.add(**{name: estimator}, method_mapping=mapping)
        return router
class VotingClassifier(ClassifierMixin, _BaseVoting):
    """Soft Voting/Majority Rule classifier for unfitted estimators.

    Read more in the :ref:`User Guide <voting_classifier>`.

    .. versionadded:: 0.17

    Parameters
    ----------
    estimators : list of (str, estimator) tuples
        Invoking the ``fit`` method on the ``VotingClassifier`` will fit clones
        of those original estimators that will be stored in the class attribute
        ``self.estimators_``. An estimator can be set to ``'drop'`` using
        :meth:`set_params`.

        .. versionchanged:: 0.21
            ``'drop'`` is accepted. Using None was deprecated in 0.22 and
            support was removed in 0.24.

    voting : {'hard', 'soft'}, default='hard'
        If 'hard', uses predicted class labels for majority rule voting.
        Else if 'soft', predicts the class label based on the argmax of
        the sums of the predicted probabilities, which is recommended for
        an ensemble of well-calibrated classifiers.

    weights : array-like of shape (n_classifiers,), default=None
        Sequence of weights (`float` or `int`) to weight the occurrences of
        predicted class labels (`hard` voting) or class probabilities
        before averaging (`soft` voting). Uses uniform weights if `None`.

    n_jobs : int, default=None
        The number of jobs to run in parallel for ``fit``.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

        .. versionadded:: 0.18

    flatten_transform : bool, default=True
        Affects shape of transform output only when voting='soft'.
        If voting='soft' and flatten_transform=True, transform method returns
        matrix with shape (n_samples, n_classifiers * n_classes). If
        flatten_transform=False, it returns
        (n_classifiers, n_samples, n_classes).

    verbose : bool, default=False
        If True, the time elapsed while fitting will be printed as it
        is completed.

        .. versionadded:: 0.23

    Attributes
    ----------
    estimators_ : list of classifiers
        The collection of fitted sub-estimators as defined in ``estimators``
        that are not 'drop'.

    named_estimators_ : :class:`~sklearn.utils.Bunch`
        Attribute to access any fitted sub-estimators by name.

        .. versionadded:: 0.20

    le_ : :class:`~sklearn.preprocessing.LabelEncoder`
        Transformer used to encode the labels during fit and decode during
        prediction.

    classes_ : ndarray of shape (n_classes,)
        The class labels.

    n_features_in_ : int
        Number of features seen during :term:`fit`. Only defined if the
        underlying classifier exposes such an attribute when fit.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Only defined if the
        underlying estimators expose such an attribute when fit.

        .. versionadded:: 1.0

    See Also
    --------
    VotingRegressor : Prediction voting regressor.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.linear_model import LogisticRegression
    >>> from sklearn.naive_bayes import GaussianNB
    >>> from sklearn.ensemble import RandomForestClassifier, VotingClassifier
    >>> clf1 = LogisticRegression(random_state=1)
    >>> clf2 = RandomForestClassifier(n_estimators=50, random_state=1)
    >>> clf3 = GaussianNB()
    >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
    >>> y = np.array([1, 1, 1, 2, 2, 2])
    >>> eclf1 = VotingClassifier(estimators=[
    ...         ('lr', clf1), ('rf', clf2), ('gnb', clf3)], voting='hard')
    >>> eclf1 = eclf1.fit(X, y)
    >>> print(eclf1.predict(X))
    [1 1 1 2 2 2]
    >>> np.array_equal(eclf1.named_estimators_.lr.predict(X),
    ...                eclf1.named_estimators_['lr'].predict(X))
    True
    >>> eclf2 = VotingClassifier(estimators=[
    ...         ('lr', clf1), ('rf', clf2), ('gnb', clf3)],
    ...         voting='soft')
    >>> eclf2 = eclf2.fit(X, y)
    >>> print(eclf2.predict(X))
    [1 1 1 2 2 2]

    To drop an estimator, :meth:`set_params` can be used to remove it. Here we
    dropped one of the estimators, resulting in 2 fitted estimators:

    >>> eclf2 = eclf2.set_params(lr='drop')
    >>> eclf2 = eclf2.fit(X, y)
    >>> len(eclf2.estimators_)
    2

    Setting `flatten_transform=True` with `voting='soft'` flattens output shape of
    `transform`:

    >>> eclf3 = VotingClassifier(estimators=[
    ...        ('lr', clf1), ('rf', clf2), ('gnb', clf3)],
    ...        voting='soft', weights=[2,1,1],
    ...        flatten_transform=True)
    >>> eclf3 = eclf3.fit(X, y)
    >>> print(eclf3.predict(X))
    [1 1 1 2 2 2]
    >>> print(eclf3.transform(X).shape)
    (6, 6)
    """

    _parameter_constraints: dict = {
        **_BaseVoting._parameter_constraints,
        "voting": [StrOptions({"hard", "soft"})],
        "flatten_transform": ["boolean"],
    }

    def __init__(
        self,
        estimators,
        *,
        voting="hard",
        weights=None,
        n_jobs=None,
        flatten_transform=True,
        verbose=False,
    ):
        super().__init__(estimators=estimators)
        self.voting = voting
        self.weights = weights
        self.n_jobs = n_jobs
        self.flatten_transform = flatten_transform
        self.verbose = verbose

    @_fit_context(
        # estimators in VotingClassifier.estimators are not validated yet
        prefer_skip_nested_validation=False
    )
    # TODO(1.7): remove `sample_weight` from the signature after deprecation
    # cycle; pop it from `fit_params` before the `_raise_for_params` check and
    # reinsert later, for backwards compatibility
    @_deprecate_positional_args(version="1.7")
    def fit(self, X, y, *, sample_weight=None, **fit_params):
        """Fit the estimators.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            Target values.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, then samples are equally weighted.
            Note that this is supported only if all underlying estimators
            support sample weights.

            .. versionadded:: 0.18

        **fit_params : dict
            Parameters to pass to the underlying estimators.

            .. versionadded:: 1.5

                Only available if `enable_metadata_routing=True`,
                which can be set by using
                ``sklearn.set_config(enable_metadata_routing=True)``.
                See :ref:`Metadata Routing User Guide <metadata_routing>` for
                more details.

        Returns
        -------
        self : object
            Returns the instance itself.

        Raises
        ------
        ValueError
            If ``y`` is a continuous or unknown target.
        NotImplementedError
            If ``y`` is neither binary nor multiclass (e.g. multilabel).
        """
        # Checked before `sample_weight` is reinserted into `fit_params`, so
        # that `sample_weight` alone never triggers the routing error.
        _raise_for_params(fit_params, self, "fit")
        y_type = type_of_target(y, input_name="y")
        if y_type in ("unknown", "continuous"):
            # raise a specific ValueError for non-classification tasks
            raise ValueError(
                f"Unknown label type: {y_type}. Maybe you are trying to fit a "
                "classifier, which expects discrete classes on a "
                "regression target with continuous values."
            )
        elif y_type not in ("binary", "multiclass"):
            # raise a NotImplementedError for backward compatibility for non-supported
            # classification tasks
            raise NotImplementedError(
                f"{self.__class__.__name__} only supports binary or multiclass "
                "classification. Multilabel and multi-output classification are not "
                "supported."
            )

        # Sub-estimators are trained on integer-encoded labels; `predict`
        # decodes the winning index back with `le_.inverse_transform`.
        self.le_ = LabelEncoder().fit(y)
        self.classes_ = self.le_.classes_
        transformed_y = self.le_.transform(y)

        if sample_weight is not None:
            fit_params["sample_weight"] = sample_weight

        return super().fit(X, transformed_y, **fit_params)

    def predict(self, X):
        """Predict class labels for X.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples.

        Returns
        -------
        maj : array-like of shape (n_samples,)
            Predicted class labels.
        """
        check_is_fitted(self)
        if self.voting == "soft":
            maj = np.argmax(self.predict_proba(X), axis=1)

        else:  # 'hard' voting
            # One row per sample, one column per fitted estimator, holding the
            # label indices produced by estimators trained on `le_`-encoded y.
            predictions = self._predict(X)
            # The vote weights are the same for every row: resolve the
            # property once instead of once per sample inside the lambda.
            weights = self._weights_not_none
            maj = np.apply_along_axis(
                lambda x: np.argmax(np.bincount(x, weights=weights)),
                axis=1,
                arr=predictions,
            )

        maj = self.le_.inverse_transform(maj)

        return maj

    def _collect_probas(self, X):
        """Collect results from clf.predict_proba calls.

        Returns the per-estimator probability arrays stacked along a new
        leading axis (one entry per fitted estimator).
        """
        return np.asarray([clf.predict_proba(X) for clf in self.estimators_])

    def _check_voting(self):
        """Gate for `available_if`: probabilities exist only for soft voting."""
        if self.voting == "hard":
            raise AttributeError(
                f"predict_proba is not available when voting={repr(self.voting)}"
            )
        return True

    @available_if(_check_voting)
    def predict_proba(self, X):
        """Compute probabilities of possible outcomes for samples in X.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples.

        Returns
        -------
        avg : array-like of shape (n_samples, n_classes)
            Weighted average probability for each class per sample.
        """
        check_is_fitted(self)
        avg = np.average(
            self._collect_probas(X), axis=0, weights=self._weights_not_none
        )
        return avg

    def transform(self, X):
        """Return class labels or probabilities for X for each estimator.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        Returns
        -------
        probabilities_or_labels
            If `voting='soft'` and `flatten_transform=True`:
                returns ndarray of shape (n_samples, n_classifiers * n_classes),
                being class probabilities calculated by each classifier.
            If `voting='soft'` and `flatten_transform=False`:
                ndarray of shape (n_classifiers, n_samples, n_classes)
            If `voting='hard'`:
                ndarray of shape (n_samples, n_classifiers), being
                class labels predicted by each classifier.
        """
        check_is_fitted(self)

        if self.voting == "soft":
            probas = self._collect_probas(X)
            if not self.flatten_transform:
                return probas
            return np.hstack(probas)

        else:
            return self._predict(X)

    def get_feature_names_out(self, input_features=None):
        """Get output feature names for transformation.

        Parameters
        ----------
        input_features : array-like of str or None, default=None
            Not used, present here for API consistency by convention.

        Returns
        -------
        feature_names_out : ndarray of str objects
            Transformed feature names.

        Raises
        ------
        ValueError
            If `voting='soft'` and `flatten_transform=False`, since the
            transform output is then 3-D and has no column names.
        """
        check_is_fitted(self, "n_features_in_")
        if self.voting == "soft" and not self.flatten_transform:
            raise ValueError(
                "get_feature_names_out is not supported when `voting='soft'` and "
                "`flatten_transform=False`"
            )

        _check_feature_names_in(self, input_features, generate_names=False)
        class_name = self.__class__.__name__.lower()

        active_names = [name for name, est in self.estimators if est != "drop"]

        if self.voting == "hard":
            return np.asarray(
                [f"{class_name}_{name}" for name in active_names], dtype=object
            )

        # voting == "soft": one name per (estimator, class) column, in the
        # estimator-major order produced by np.hstack in `transform`.
        n_classes = len(self.classes_)
        names_out = [
            f"{class_name}_{name}{i}" for name in active_names for i in range(n_classes)
        ]
        return np.asarray(names_out, dtype=object)
class VotingRegressor(RegressorMixin, _BaseVoting):
    """Prediction voting regressor for unfitted estimators.

    A voting regressor is an ensemble meta-estimator that fits several base
    regressors, each on the whole dataset. Then it averages the individual
    predictions to form a final prediction.

    Read more in the :ref:`User Guide <voting_regressor>`.

    .. versionadded:: 0.21

    Parameters
    ----------
    estimators : list of (str, estimator) tuples
        Invoking the ``fit`` method on the ``VotingRegressor`` will fit clones
        of those original estimators that will be stored in the class attribute
        ``self.estimators_``. An estimator can be set to ``'drop'`` using
        :meth:`set_params`.

        .. versionchanged:: 0.21
            ``'drop'`` is accepted. Using None was deprecated in 0.22 and
            support was removed in 0.24.

    weights : array-like of shape (n_regressors,), default=None
        Sequence of weights (`float` or `int`) to weight the occurrences of
        predicted values before averaging. Uses uniform weights if `None`.

    n_jobs : int, default=None
        The number of jobs to run in parallel for ``fit``.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

    verbose : bool, default=False
        If True, the time elapsed while fitting will be printed as it
        is completed.

        .. versionadded:: 0.23

    Attributes
    ----------
    estimators_ : list of regressors
        The collection of fitted sub-estimators as defined in ``estimators``
        that are not 'drop'.

    named_estimators_ : :class:`~sklearn.utils.Bunch`
        Attribute to access any fitted sub-estimators by name.

        .. versionadded:: 0.20

    n_features_in_ : int
        Number of features seen during :term:`fit`. Only defined if the
        underlying regressor exposes such an attribute when fit.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Only defined if the
        underlying estimators expose such an attribute when fit.

        .. versionadded:: 1.0

    See Also
    --------
    VotingClassifier : Soft Voting/Majority Rule classifier.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.linear_model import LinearRegression
    >>> from sklearn.ensemble import RandomForestRegressor
    >>> from sklearn.ensemble import VotingRegressor
    >>> from sklearn.neighbors import KNeighborsRegressor
    >>> r1 = LinearRegression()
    >>> r2 = RandomForestRegressor(n_estimators=10, random_state=1)
    >>> r3 = KNeighborsRegressor()
    >>> X = np.array([[1, 1], [2, 4], [3, 9], [4, 16], [5, 25], [6, 36]])
    >>> y = np.array([2, 6, 12, 20, 30, 42])
    >>> er = VotingRegressor([('lr', r1), ('rf', r2), ('r3', r3)])
    >>> print(er.fit(X, y).predict(X))
    [ 6.8...  8.4... 12.5... 17.8... 26...  34...]

    In the following example, we drop the `'lr'` estimator with
    :meth:`~VotingRegressor.set_params` and fit the remaining two estimators:

    >>> er = er.set_params(lr='drop')
    >>> er = er.fit(X, y)
    >>> len(er.estimators_)
    2
    """

    def __init__(self, estimators, *, weights=None, n_jobs=None, verbose=False):
        super().__init__(estimators=estimators)
        self.weights = weights
        self.n_jobs = n_jobs
        self.verbose = verbose

    @_fit_context(
        # estimators in VotingRegressor.estimators are not validated yet
        prefer_skip_nested_validation=False
    )
    # TODO(1.7): remove `sample_weight` from the signature after deprecation cycle;
    # pop it from `fit_params` before the `_raise_for_params` check and reinsert later,
    # for backwards compatibility
    @_deprecate_positional_args(version="1.7")
    def fit(self, X, y, *, sample_weight=None, **fit_params):
        """Fit the estimators.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            Target values.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, then samples are equally weighted.
            Note that this is supported only if all underlying estimators
            support sample weights.

        **fit_params : dict
            Parameters to pass to the underlying estimators.

            .. versionadded:: 1.5

                Only available if `enable_metadata_routing=True`,
                which can be set by using
                ``sklearn.set_config(enable_metadata_routing=True)``.
                See :ref:`Metadata Routing User Guide <metadata_routing>` for
                more details.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        # Checked before `sample_weight` is reinserted into `fit_params`, so
        # that `sample_weight` alone never triggers the routing error.
        _raise_for_params(fit_params, self, "fit")
        # Ravel y to 1-D (with a warning for column vectors).
        y = column_or_1d(y, warn=True)

        if sample_weight is not None:
            fit_params["sample_weight"] = sample_weight

        return super().fit(X, y, **fit_params)

    def predict(self, X):
        """Predict regression target for X.

        The predicted regression target of an input sample is the average of
        the targets predicted by the estimators in the ensemble, weighted by
        ``weights`` when it is given.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples.

        Returns
        -------
        y : ndarray of shape (n_samples,)
            The predicted values.
        """
        check_is_fitted(self)
        # Average across the estimator axis (columns of `_predict`).
        return np.average(self._predict(X), axis=1, weights=self._weights_not_none)

    def transform(self, X):
        """Return predictions for X for each estimator.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples.

        Returns
        -------
        predictions : ndarray of shape (n_samples, n_regressors)
            Values predicted by each regressor.
        """
        check_is_fitted(self)
        return self._predict(X)

    def get_feature_names_out(self, input_features=None):
        """Get output feature names for transformation.

        Parameters
        ----------
        input_features : array-like of str or None, default=None
            Not used, present here for API consistency by convention.

        Returns
        -------
        feature_names_out : ndarray of str objects
            Transformed feature names, one per non-dropped estimator.
        """
        check_is_fitted(self, "n_features_in_")
        _check_feature_names_in(self, input_features, generate_names=False)
        class_name = self.__class__.__name__.lower()
        return np.asarray(
            [f"{class_name}_{name}" for name, est in self.estimators if est != "drop"],
            dtype=object,
        )