-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(encoder): add the IncrementalPCAEncoder
- Loading branch information
Showing
4 changed files
with
159 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,3 +30,4 @@ torchvision: framework | |
onnx: framework, py37 | ||
onnxruntime: framework, py37 | ||
annoy: index | ||
sklearn: numeric |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import numpy as np | ||
import os | ||
from ...decorators import batching, require_train | ||
|
||
from .. import BaseNumericEncoder | ||
|
||
|
||
class IncrementalPCAEncoder(BaseNumericEncoder): | ||
""" | ||
:class:`IncrementalPCAEncoder` encodes data from an ndarray in size `B x T` into an ndarray in size `B x D`. | ||
.. note:: | ||
:class:`IncrementalPCAEncoder` must be trained before calling ``encode()``. This encoder can be trained in an | ||
incremental way. | ||
""" | ||
def __init__(self, | ||
output_dim: int, | ||
whiten: bool = False, | ||
num_features: int = None, | ||
save_path: str = '', | ||
*args, | ||
**kwargs): | ||
""" | ||
:param output_dim: the output size. | ||
:param whiten: If whiten is false, the data is already considered to be whitened, and no whitening is performed. | ||
:param num_features: the number of input features. If ``num_features`` is None, then ``num_features`` is | ||
inferred from the data | ||
:param encoder_abspath: the absolute saving path of the encoder. If a valid path is given, the encoder will be | ||
loaded from the given path. | ||
""" | ||
super().__init__(*args, **kwargs) | ||
self.output_dim = output_dim | ||
self.whiten = whiten | ||
self.num_features = num_features | ||
self.encoder_abspath = save_path | ||
self.is_trained = False | ||
self._args = args | ||
self._kwargs = kwargs | ||
|
||
def post_init(self): | ||
from sklearn.decomposition import IncrementalPCA | ||
if os.path.exists(self.encoder_abspath): | ||
import pickle | ||
with open(self.encoder_abspath, 'rb') as f: | ||
self.model = pickle.load(f) | ||
self.logger.info('load existing model from {}'.format(self.encoder_abspath)) | ||
else: | ||
self.model = IncrementalPCA( | ||
n_components=self.output_dim, | ||
whiten=self.whiten, | ||
*self._args, | ||
**self._kwargs) | ||
|
||
@batching | ||
def train(self, data: 'np.ndarray', *args, **kwargs): | ||
num_samples, num_features = data.shape | ||
if not self.num_features: | ||
self.num_features = num_features | ||
self._check_num_features(num_features) | ||
if num_samples < 5 * num_features: | ||
self.logger.warning( | ||
'the batch size (={}) is suggested to be 5 * num_features(={}) to provide a balance between ' | ||
'approximation accuracy and memory consumption.'.format(num_samples, num_features)) | ||
self.model.partial_fit(data) | ||
self.is_trained = True | ||
|
||
@require_train | ||
@batching | ||
def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray': | ||
_, num_features = data.shape | ||
self._check_num_features(num_features) | ||
return self.model.transform(data) | ||
|
||
def _check_num_features(self, num_features): | ||
if self.num_features != num_features: | ||
raise ValueError( | ||
'the number of features must be consistent. ({} != {})'.format(num_features, self.num_features) | ||
) | ||
|
||
def __getstate__(self): | ||
if not self.encoder_abspath: | ||
self.encoder_abspath = os.path.join(self.current_workspace, "pca.bin") | ||
if os.path.exists(self.encoder_abspath): | ||
self.logger.warning( | ||
'the existed model file will be overrided: {}".format(save_path)') | ||
self.logger.info( | ||
'the model is saved at: {}'.format(self.encoder_abspath)) | ||
import pickle | ||
with open(self.encoder_abspath, 'wb') as f: | ||
pickle.dump(self.model, f) | ||
return super().__getstate__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
import os | ||
|
||
from . import JinaTestCase | ||
from jina.executors.encoders.numeric.pca import IncrementalPCAEncoder | ||
from jina.executors import BaseExecutor | ||
|
||
|
||
class MyTestCase(JinaTestCase): | ||
num_features = 28 | ||
output_dim = 2 | ||
|
||
def test_encoding_results(self): | ||
encoder = IncrementalPCAEncoder( | ||
output_dim=self.output_dim, whiten=True, num_features=self.num_features) | ||
train_data = np.random.rand(1000, self.num_features) | ||
encoder.train(train_data) | ||
self.assertTrue(encoder.is_trained) | ||
|
||
test_data = np.random.rand(10, self.num_features) | ||
encoded_data = encoder.encode(test_data) | ||
self.assertEqual(encoded_data.shape, (test_data.shape[0], self.output_dim)) | ||
self.assertIs(type(encoded_data), np.ndarray) | ||
|
||
def test_save_and_load(self): | ||
encoder = IncrementalPCAEncoder( | ||
output_dim=self.output_dim, whiten=True, num_features=self.num_features) | ||
train_data = np.random.rand(1000, self.num_features) | ||
encoder.train(train_data) | ||
test_data = np.random.rand(10, self.num_features) | ||
encoded_data_control = encoder.encode(test_data) | ||
|
||
encoder.touch() | ||
encoder.save() | ||
self.assertTrue(os.path.exists(encoder.save_abspath)) | ||
encoder_loaded = BaseExecutor.load(encoder.save_abspath) | ||
encoded_data_test = encoder_loaded.encode(test_data) | ||
|
||
self.assertEqual( | ||
encoder_loaded.model.n_samples_seen_, | ||
encoder.model.n_samples_seen_) | ||
np.testing.assert_array_equal( | ||
encoded_data_test, encoded_data_control) | ||
self.add_tmpfile( | ||
encoder.config_abspath, encoder.save_abspath, encoder_loaded.config_abspath, encoder_loaded.save_abspath, | ||
encoder.encoder_abspath) | ||
|
||
def test_save_and_load_config(self): | ||
encoder = IncrementalPCAEncoder( | ||
output_dim=self.output_dim, whiten=True, num_features=self.num_features) | ||
encoder.save_config() | ||
self.assertTrue(os.path.exists(encoder.config_abspath)) | ||
|
||
encoder_loaded = BaseExecutor.load_config(encoder.config_abspath) | ||
|
||
self.assertEqual( | ||
encoder_loaded.output_dim, | ||
encoder.output_dim) | ||
|
||
self.add_tmpfile(encoder_loaded.config_abspath, encoder_loaded.save_abspath) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |