-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathensemble.py
59 lines (44 loc) · 1.57 KB
/
ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Defines the Ensemble class."""
from typing import Callable, List
import numpy as np
import flexs
from flexs.types import SEQUENCES_TYPE
class Ensemble(flexs.Model):
    """
    Class to ensemble models or landscapes together.

    Attributes:
        models (List[flexs.Landscape]): List of landscapes/models being ensembled.
        combine_with (Callable[[np.ndarray], np.ndarray]): Function to combine ensemble
            predictions.
    """

    def __init__(
        self,
        models: List[flexs.Landscape],
        combine_with: Callable[[np.ndarray], np.ndarray] = lambda x: np.mean(x, axis=1),
    ):
        """
        Create ensemble.

        The ensemble's name is built from its members' names, e.g.
        ``Ens(a|b|c)``.

        Args:
            models: List of landscapes/models to ensemble. The argument is
                copied into a new list, so single-pass iterables (such as
                generators) are also accepted and are not exhausted while the
                ensemble's name is being built.
            combine_with: A function that takes in a matrix of scores
                (num_seqs, num_models) and combines ensembled model scores into an
                array (num_seqs,). The default averages over the models axis.

        Raises:
            ValueError: If `models` contains no models.
        """
        # Materialize once: building the name below iterates `models`, which
        # would otherwise silently exhaust a generator and leave `self.models`
        # empty.
        models = list(models)
        if not models:
            # `_fitness_function` stacks one score array per model, and
            # np.stack cannot stack zero arrays; fail early with a clear message.
            raise ValueError("Ensemble requires at least one model.")

        name = f"Ens({'|'.join(model.name for model in models)})"
        super().__init__(name)

        self.models = models
        self.combine_with = combine_with

    def train(self, sequences: SEQUENCES_TYPE, labels: np.ndarray):
        """
        Train each model in `self.models`.

        Every model is trained on the same `sequences` and `labels`, in the
        order the models appear in `self.models`.

        NOTE(review): this calls `train` on every member, so it presumably
        requires each member to be a trainable model rather than a bare
        landscape; confirm against the `flexs.Landscape` interface.

        Args:
            sequences: Training sequences
            labels: Training labels
        """
        for model in self.models:
            model.train(sequences, labels)

    def _fitness_function(self, sequences: SEQUENCES_TYPE) -> np.ndarray:
        """
        Score `sequences` with every ensembled model and combine the scores.

        Each model's `get_fitness` output is stacked as one column of a score
        matrix (stacking along axis 1), and that matrix is passed to
        `self.combine_with`. This assumes each model returns a 1-D array of
        per-sequence scores, which gives a (num_seqs, num_models) matrix.

        Args:
            sequences: Sequences to score.

        Returns:
            Whatever `self.combine_with` returns for the stacked score matrix;
            an array of shape (num_seqs,) for a combiner that follows the
            documented contract.
        """
        scores = np.stack(
            [model.get_fitness(sequences) for model in self.models], axis=1
        )
        return self.combine_with(scores)