Skip to content

Commit

Permalink
Merge pull request #286 from jina-ai/fix-drivers-blob-285
Browse files Browse the repository at this point in the history
fix(drivers): add handling for blob
  • Loading branch information
hanxiao authored Apr 21, 2020
2 parents 0d7be8d + fa1a449 commit fc30d7c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
12 changes: 9 additions & 3 deletions jina/drivers/craft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import random

from . import BaseExecutableDriver
from .helper import array2blob, pb_obj2dict
from .helper import array2blob, pb_obj2dict, blob2array


class BaseCraftDriver(BaseExecutableDriver):
Expand All @@ -26,10 +26,16 @@ def __call__(self, *args, **kwargs):
continue
_chunks_to_add = []
for c in d.chunks:
ret = self.exec_fn(**pb_obj2dict(c, self.exec.required_keys))
_args_dict = pb_obj2dict(c, self.exec.required_keys)
if 'blob' in self.exec.required_keys:
_args_dict['blob'] = blob2array(c.blob)
ret = self.exec_fn(**_args_dict)
if isinstance(ret, dict):
for k, v in ret.items():
setattr(c, k, v)
if k == 'blob':
c.blob.CopyFrom(array2blob(v))
else:
setattr(c, k, v)
continue
if isinstance(ret, list):
for chunk_dict in ret:
Expand Down
1 change: 0 additions & 1 deletion jina/executors/indexers/vector/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> Tuple['np.ndar

idx = dist.argsort(axis=1)[:, :top_k]
dist = np.take_along_axis(dist, idx, axis=1)

return self.int2ext_key[idx], dist


Expand Down
1 change: 1 addition & 0 deletions tests/executors/encoders/image/test_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
class MyTestCase(ImageTestCase):
def _get_encoder(self):
self.target_output_dim = 1280
self.input_dim = 224
return ImageTorchEncoder()


Expand Down

0 comments on commit fc30d7c

Please sign in to comment.