From b1b2c4115a34ee5d941b37e53c33749f5a789aa2 Mon Sep 17 00:00:00 2001 From: Nan Wang Date: Wed, 22 Apr 2020 00:36:28 +0800 Subject: [PATCH 1/2] fix(drivers): add handling for blob --- jina/drivers/craft.py | 12 +++++++++--- jina/executors/indexers/vector/numpy.py | 5 +++-- tests/executors/encoders/image/test_torchvision.py | 1 + 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/jina/drivers/craft.py b/jina/drivers/craft.py index 705dce6379dd3..379a42a1c9b64 100644 --- a/jina/drivers/craft.py +++ b/jina/drivers/craft.py @@ -2,7 +2,7 @@ import random from . import BaseExecutableDriver -from .helper import array2blob, pb_obj2dict +from .helper import array2blob, pb_obj2dict, blob2array class BaseCraftDriver(BaseExecutableDriver): @@ -26,10 +26,16 @@ def __call__(self, *args, **kwargs): continue _chunks_to_add = [] for c in d.chunks: - ret = self.exec_fn(**pb_obj2dict(c, self.exec.required_keys)) + _args_dict = pb_obj2dict(c, self.exec.required_keys) + if 'blob' in self.exec.required_keys: + _args_dict['blob'] = blob2array(c.blob) + ret = self.exec_fn(**_args_dict) if isinstance(ret, dict): for k, v in ret.items(): - setattr(c, k, v) + if k == 'blob': + c.blob.CopyFrom(array2blob(v)) + else: + setattr(c, k, v) continue if isinstance(ret, list): for chunk_dict in ret: diff --git a/jina/executors/indexers/vector/numpy.py b/jina/executors/indexers/vector/numpy.py index f19c61541821f..5faa4c99a6ba9 100644 --- a/jina/executors/indexers/vector/numpy.py +++ b/jina/executors/indexers/vector/numpy.py @@ -132,8 +132,9 @@ def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> Tuple['np.ndar idx = dist.argsort(axis=1)[:, :top_k] dist = np.take_along_axis(dist, idx, axis=1) - - return self.int2ext_key[idx], dist + _min = dist.min() + _max = dist.max() + return self.int2ext_key[idx], dist #(_max - dist) / (_max - _min) def _ext_arrs(A, B): diff --git a/tests/executors/encoders/image/test_torchvision.py b/tests/executors/encoders/image/test_torchvision.py index 1bcd5478df210..01cbb30731b6d 100644 --- a/tests/executors/encoders/image/test_torchvision.py +++ b/tests/executors/encoders/image/test_torchvision.py @@ -7,6 +7,7 @@ class MyTestCase(ImageTestCase): def _get_encoder(self): self.target_output_dim = 1280 + self.input_dim = 224 return ImageTorchEncoder() From fa1a44954cf0a5967989c13e052d705e698ab46b Mon Sep 17 00:00:00 2001 From: Nan Wang Date: Wed, 22 Apr 2020 00:40:46 +0800 Subject: [PATCH 2/2] fix: undo the changes --- jina/executors/indexers/vector/numpy.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/jina/executors/indexers/vector/numpy.py b/jina/executors/indexers/vector/numpy.py index 5faa4c99a6ba9..fe0db2cfd9a66 100644 --- a/jina/executors/indexers/vector/numpy.py +++ b/jina/executors/indexers/vector/numpy.py @@ -132,9 +132,7 @@ def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> Tuple['np.ndar idx = dist.argsort(axis=1)[:, :top_k] dist = np.take_along_axis(dist, idx, axis=1) - _min = dist.min() - _max = dist.max() - return self.int2ext_key[idx], dist #(_max - dist) / (_max - _min) + return self.int2ext_key[idx], dist def _ext_arrs(A, B):