Skip to content

Commit

Permalink
refactor(client): remove client input type
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed May 28, 2020
1 parent 82c7c52 commit f14c146
Show file tree
Hide file tree
Showing 12 changed files with 35 additions and 41 deletions.
7 changes: 3 additions & 4 deletions jina/clients/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . import request
from .grpc import GrpcClient
from .helper import ProgressBar
from ...enums import ClientInputType, ClientMode
from ...enums import ClientMode
from ...excepts import BadClient
from ...logging import default_logger
from ...logging.profile import TimeContext
Expand Down Expand Up @@ -64,14 +64,13 @@ def mode(self, value: ClientMode):
raise ValueError(f'{value} must be one of {ClientMode}')

@staticmethod
def check_input(input_fn: Union[Iterator['jina_pb2.Document'], Iterator[bytes], Callable] = None,
input_type: ClientInputType = ClientInputType.BUFFER):
def check_input(input_fn: Union[Iterator['jina_pb2.Document'], Iterator[bytes], Callable] = None):
"""Validate the input_fn and print the first request if success
:param input_fn: the input function
:param input_type: if the input data is in protobuf Document format, or in raw bytes, or data uri
"""
kwargs = {'data': input_fn, 'input_type': input_type}
kwargs = {'data': input_fn}

try:
r = next(getattr(request, 'index')(**kwargs))
Expand Down
5 changes: 1 addition & 4 deletions jina/main/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def set_client_cli_parser(parser=None):
if not parser:
parser = set_base_parser()

from ..enums import ClientInputType, ClientMode
from ..enums import ClientMode

_set_grpc_parser(parser)

Expand All @@ -415,9 +415,6 @@ def set_client_cli_parser(parser=None):
gp1.add_argument('--top-k', type=int,
default=10,
help='top_k results returned in the search mode')
gp1.add_argument('--input-type', choices=list(ClientInputType), default=ClientInputType.BUFFER,
type=ClientInputType.from_string,
help='the type of input data')
gp1.add_argument('--mime-type', type=str,
help='MIME type of the input, useful when input-type is set to BUFFER')
gp1.add_argument('--callback-on-body', action='store_true', default=False,
Expand Down
3 changes: 1 addition & 2 deletions jina/peapods/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .pea import BasePea
from .zmq import AsyncZmqlet, add_envelope
from .. import __stop_msg__
from ..enums import ClientInputType, ClientMode
from ..enums import ClientMode
from ..excepts import NoExplicitMessage, RequestLoopEnd, NoDriverForRequest, BadRequestType
from ..executors import BaseExecutor
from ..logging.base import get_logger
Expand Down Expand Up @@ -240,7 +240,6 @@ def api(mode):
return http_error('"data" field is empty', 406)

content['mode'] = ClientMode.from_string(mode)
content['input_type'] = ClientInputType.from_string(content.get('input_type', 'data_uri'))

results = get_result_in_json(getattr(python.request, mode)(**content))
return Response(asyncio.run(results),
Expand Down
7 changes: 3 additions & 4 deletions tests/executors/crafters/test_mime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import glob

from jina.enums import ClientInputType
from jina.flow import Flow
from tests import JinaTestCase

Expand Down Expand Up @@ -41,7 +40,7 @@ def test_dummy_seg(self):
def test_any_file(self):
f = Flow().add(yaml_path='!FilePath2DataURI\nwith: {base64: true}')
with f:
f.index(input_fn=input_fn2, output_fn=print, input_type=ClientInputType.FILE_PATH)
f.index(input_fn=input_fn2, output_fn=print)

def test_aba(self):
f = (Flow().add(yaml_path='!Buffer2DataURI\nwith: {mimetype: png}')
Expand All @@ -56,9 +55,9 @@ def test_any2buffer(self):
.add(yaml_path='Buffer2DataURI'))

with f:
f.index(input_fn=input_fn3, output_fn=print, input_type=ClientInputType.DATA_URI)
f.index(input_fn=input_fn3, output_fn=print)

# def test_dummy_seg_random(self):
# f = Flow().add(yaml_path='../../yaml/dummy-seg-random.yml')
# with f:
# f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF, output_fn=self.collect_chunk_id)
# f.index(input_fn=random_docs(10), output_fn=self.collect_chunk_id)
4 changes: 2 additions & 2 deletions tests/executors/crafters/test_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def collect_chunk_id(self, req):
def test_dummy_seg(self):
f = Flow().add(yaml_path='DummySegment')
with f:
f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF, output_fn=self.get_chunk_id)
f.index(input_fn=random_docs(10), output_fn=self.get_chunk_id)

def test_dummy_seg_random(self):
f = Flow().add(yaml_path='../../yaml/dummy-seg-random.yml')
with f:
f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF, output_fn=self.collect_chunk_id)
f.index(input_fn=random_docs(10), output_fn=self.collect_chunk_id)
2 changes: 1 addition & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_check_input(self):
input_fn = iter([b'1234', b'45467'])
PyClient.check_input(input_fn)
input_fn = iter([Document(), Document()])
PyClient.check_input(input_fn, input_type=ClientInputType.PROTOBUF)
PyClient.check_input(input_fn)
bad_input_fn = iter([b'1234', '45467', [12, 2, 3]])
self.assertRaises(TypeError, PyClient.check_input, bad_input_fn)
bad_input_fn = iter([Document(), None])
Expand Down
18 changes: 9 additions & 9 deletions tests/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ def test_flow_with_one_container_pod(self):
.add(name='dummyEncoder', image=img_name))

with f:
f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(10))

def test_flow_with_one_container_ext_yaml(self):
f = (Flow()
.add(name='dummyEncoder', image=img_name, yaml_path='./mwu-encoder/mwu_encoder_ext.yml'))

with f:
f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(10))

def test_flow_with_replica_container_ext_yaml(self):
f = (Flow()
Expand All @@ -92,9 +92,9 @@ def test_flow_with_replica_container_ext_yaml(self):
replicas=3))

with f:
f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(10))
f.index(input_fn=random_docs(10))
f.index(input_fn=random_docs(10))

def test_flow_topo1(self):
f = (Flow()
Expand All @@ -105,7 +105,7 @@ def test_flow_topo1(self):
.join(['d3', 'd2']))

with f:
f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(10))

def test_flow_topo_mixed(self):
f = (Flow()
Expand All @@ -117,7 +117,7 @@ def test_flow_topo_mixed(self):
)

with f:
f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(10))

def test_flow_topo_replicas(self):
f = (Flow()
Expand All @@ -130,14 +130,14 @@ def test_flow_topo_replicas(self):

with f:
f.dry_run()
f.index(input_fn=random_docs(1000), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(1000))

def test_container_volume(self):
f = (Flow()
.add(name='dummyEncoder', image=img_name, volumes='./abc', yaml_path='mwu-encoder/mwu_encoder_upd.yml'))

with f:
f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(10))

out_file = './abc/ext-mwu-encoder.bin'
self.assertTrue(os.path.exists(out_file))
Expand Down
8 changes: 4 additions & 4 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_flow_no_container(self):
.add(name='dummyEncoder', yaml_path='mwu-encoder/mwu_encoder.yml'))

with f:
f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(10))

def test_flow_yaml_dump(self):
f = Flow(logserver_config='yaml/test-server-config.yml',
Expand All @@ -165,7 +165,7 @@ def test_flow_log_server(self):
def test_shards(self):
f = Flow().add(name='doc_pb', yaml_path='yaml/test-docpb.yml', replicas=3, separated_workspace=True)
with f:
f.index(input_fn=random_docs(1000), input_type=ClientInputType.PROTOBUF, random_doc_id=False)
f.index(input_fn=random_docs(1000), random_doc_id=False)
with f:
pass
self.add_tmpfile('test-docshard')
Expand All @@ -185,13 +185,13 @@ def validate(req):

f = Flow().add(name='doc_pb', yaml_path='yaml/test-docpb.yml', replicas=replicas, separated_workspace=True)
with f:
f.index(input_fn=random_docs(index_docs), input_type=ClientInputType.PROTOBUF, random_doc_id=False)
f.index(input_fn=random_docs(index_docs), random_doc_id=False)
with f:
pass
f = Flow().add(name='doc_pb', yaml_path='yaml/test-docpb.yml', replicas=replicas,
separated_workspace=True, polling='all', reducing_yaml_path='_merge_topk_docs')
with f:
f.search(input_fn=random_queries(1, index_docs), input_type=ClientInputType.PROTOBUF, random_doc_id=False, output_fn=validate,
f.search(input_fn=random_queries(1, index_docs), random_doc_id=False, output_fn=validate,
callback_on_body=True)
self.add_tmpfile('test-docshard')

Expand Down
10 changes: 5 additions & 5 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_doc_iters(self):
def test_simple_route(self):
f = Flow().add(yaml_path='_forward')
with f:
f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(10))

def test_update_method(self):
a = DummyIndexer(index_filename='test.bin')
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_two_client_route_replicas(self):
# f3 = Flow(optimize_level=FlowOptimizeLevel.FULL).add(yaml_path='_forward', replicas=3)

def start_client(fl):
fl.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF)
fl.index(input_fn=random_docs(10))

with f1:
self.assertEqual(f1.num_peas, 6)
Expand Down Expand Up @@ -152,7 +152,7 @@ def test_two_client_route(self):
f = Flow().add(yaml_path='_forward')

def start_client(fl):
fl.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF)
fl.index(input_fn=random_docs(10))

with f:
t1 = mp.Process(target=start_client, args=(f,))
Expand All @@ -167,7 +167,7 @@ def start_client(fl):
def test_index(self):
f = Flow().add(yaml_path='yaml/test-index.yml', replicas=3, separated_workspace=True)
with f:
f.index(input_fn=random_docs(1000), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(1000))

for j in range(3):
self.assertTrue(os.path.exists(f'test2-{j + 1}/test2.bin'))
Expand All @@ -176,7 +176,7 @@ def test_index(self):

time.sleep(3)
with f:
f.search(input_fn=random_docs(1), input_type=ClientInputType.PROTOBUF, output_fn=get_result, top_k=100)
f.search(input_fn=random_docs(1), output_fn=get_result, top_k=100)


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions tests/test_index_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def start_gateway():
host='localhost', port_grpc=f_args.port_grpc)

with f:
f.index(input_fn=random_docs(1000), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(1000))

time.sleep(3)
for j in range(3):
Expand All @@ -123,7 +123,7 @@ def start_gateway():
host='192.168.31.76', port_grpc=44444))

with f:
f.index(input_fn=random_docs(1000), input_type=ClientInputType.PROTOBUF)
f.index(input_fn=random_docs(1000))


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions tests/test_loadbalance.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ def test_lb(self):
yaml_path='SlowWorker',
replicas=10)
with f:
f.index(input_fn=random_docs(100), input_type=ClientInputType.PROTOBUF, batch_size=10)
f.index(input_fn=random_docs(100), batch_size=10)

def test_roundrobin(self):
f = Flow(runtime='process').add(
name='sw',
yaml_path='SlowWorker',
replicas=10, scheduling=SchedulerType.ROUND_ROBIN)
with f:
f.index(input_fn=random_docs(100), input_type=ClientInputType.PROTOBUF, batch_size=10)
f.index(input_fn=random_docs(100), batch_size=10)
4 changes: 2 additions & 2 deletions tests/test_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def f1(self, quant):
yaml_path='_forward').add(
yaml_path='_forward').add(yaml_path='_forward').add(yaml_path='_forward').add(yaml_path='_forward')
with f as fl:
fl.index(random_docs, output_fn=get_output, input_type=ClientInputType.PROTOBUF)
fl.index(random_docs, output_fn=get_output)

def f2(self, quant):
os.environ['JINA_ARRAY_QUANT'] = quant
Expand All @@ -61,7 +61,7 @@ def f2(self, quant):
yaml_path='_forward').add(
yaml_path='_forward').add(yaml_path='_forward').add(yaml_path='_forward').add(yaml_path='_forward')
with f as fl:
fl.index(random_docs, output_fn=get_output, input_type=ClientInputType.PROTOBUF)
fl.index(random_docs, output_fn=get_output)

def test_quant(self):
for j in ('fp32', 'fp16', 'uint8'):
Expand Down

0 comments on commit f14c146

Please sign in to comment.