Skip to content

Commit

Permalink
feat(driver): add override as an option to the driver
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Jun 2, 2020
1 parent aa8d518 commit 401faab
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
40 changes: 36 additions & 4 deletions jina/drivers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,20 @@
from .helper import guess_mime, array2pb, pb2array


class MIMEDriver(BaseDriver):
class BaseConvertDriver(BaseDriver):
    """Common base of the conversion drivers; it only records the ``override`` option.

    The drivers derived from it in this module read ``self.override`` at the top
    of their per-document loop to decide whether a document whose target field
    is already filled should be converted again.
    """

    # NOTE(review): ``override`` is the first positional parameter, so the first
    # positional argument forwarded by a subclass constructor is bound to it.
    def __init__(self, override: bool = False, *args, **kwargs) -> None:
        """
        :param override: override the value even when it already exists
        :param args: extra positional arguments, forwarded to the parent constructor
        :param kwargs: extra keyword arguments, forwarded to the parent constructor
        """
        super().__init__(*args, **kwargs)
        self.override = override


class MIMEDriver(BaseConvertDriver):
"""Guessing the MIME type based on the doc content
Can be used before/after :class:`DocCraftDriver` to fill MIME type
Expand Down Expand Up @@ -43,6 +56,10 @@ def __call__(self, *args, **kwargs):
for d in self.req.docs:
# mime_type may be a file extension
m_type = d.mime_type

if m_type and not self.override:
continue

if m_type and (m_type not in mimetypes.types_map.values()):
m_type = mimetypes.guess_type(f'*.{m_type}')[0]

Expand All @@ -66,15 +83,17 @@ def __call__(self, *args, **kwargs):
self.logger.warning(f'can not determine the MIME type, set to default {self.default_mime}')


class Buffer2NdArray(BaseDriver):
class Buffer2NdArray(BaseConvertDriver):
    """Convert the raw ``buffer`` of each document into a numpy array stored in ``blob``.

    A document is converted when its ``blob`` evaluates as false, or
    unconditionally when ``self.override`` is set.
    """

    def __call__(self, *args, **kwargs):
        for doc in self.req.docs:
            # NOTE(review): this relies on the truthiness of ``doc.blob``; if
            # ``blob`` is a message-typed field its truthiness may not reflect
            # whether it was ever filled -- confirm against the document schema.
            if not doc.blob or self.override:
                # ``np.frombuffer`` is called without ``dtype``, so the bytes are
                # interpreted with numpy's default dtype -- TODO confirm this
                # matches how ``buffer`` was produced.
                doc.blob.CopyFrom(array2pb(np.frombuffer(doc.buffer)))


class Blob2PngURI(BaseDriver):
class Blob2PngURI(BaseConvertDriver):
"""Simple DocCrafter used in :command:`jina hello-world`,
it reads ``buffer`` into base64 png and stored in ``uri``"""

Expand All @@ -85,6 +104,9 @@ def __init__(self, width: int = 28, height: int = 28, *args, **kwargs):

def __call__(self, *args, **kwargs):
for d in self.req.docs:
if d.uri and not self.override:
continue

arr = pb2array(d.blob)
pixels = []
for p in arr[::-1]:
Expand All @@ -111,12 +133,16 @@ def png_pack(png_tag, data):
d.uri = 'data:image/png;base64,' + base64.b64encode(png_bytes).decode()


class URI2Buffer(BaseDriver):
class URI2Buffer(BaseConvertDriver):
""" Convert local file path, remote URL doc to a buffer doc.
"""

def __call__(self, *args, **kwargs):
for d in self.req.docs:

if d.buffer and not self.override:
continue

if urllib.parse.urlparse(d.uri).scheme in {'http', 'https', 'data'}:
page = urllib.request.Request(d.uri, headers={'User-Agent': 'Mozilla/5.0'})
tmp = urllib.request.urlopen(page)
Expand Down Expand Up @@ -144,6 +170,9 @@ def __init__(self, charset: str = 'utf-8', base64: bool = False, *args, **kwargs
def __call__(self, *args, **kwargs):
super().__call__()
for d in self.req.docs:
if d.uri and not self.override:
continue

if d.uri and urllib.parse.urlparse(d.uri).scheme == 'data':
pass
else:
Expand All @@ -169,6 +198,9 @@ class Buffer2URI(URI2DataURI):

def __call__(self, *args, **kwargs):
for d in self.req.docs:
if d.uri and not self.override:
continue

if d.uri and urllib.parse.urlparse(d.uri).scheme == 'data':
pass
else:
Expand Down
19 changes: 19 additions & 0 deletions jina/resources/executors.requests.DocURIPbIndexer.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Request-to-driver routing: for every request type, the drivers listed under
# it are given in the order shown.
# NOTE(review): the nesting below was reconstructed from the key order --
# confirm it against the original file.
on:
  ControlRequest:
    - !ControlReqDriver {}
  SearchRequest:
    # document-level lookup, followed by a top-k filter with default settings
    - !KVSearchDriver
      with:
        level: doc
    - !TopKFilterDriver {}
  IndexRequest:
    # No arguments are given, so ``override`` keeps its default (false): in
    # jina/drivers/convert.py, Buffer2URI skips documents whose ``uri`` is
    # already set unless ``override`` is true.
    - !Buffer2URI {}
    # ``chunks`` and ``buffer`` are pruned at document level before indexing;
    # presumably only the URI form is meant to be stored -- confirm.
    - !PruneDriver
      with:
        level: doc
        pruned:
          - chunks
          - buffer
    - !KVIndexDriver
      with:
        level: doc

0 comments on commit 401faab

Please sign in to comment.