From f14c146a3ebae7b5f420cc430e76753ac993208a Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Thu, 28 May 2020 09:37:29 +0200 Subject: [PATCH] refactor(client): remove client input type --- jina/clients/python/__init__.py | 7 +++---- jina/main/parser.py | 5 +---- jina/peapods/gateway.py | 3 +-- tests/executors/crafters/test_mime.py | 7 +++---- tests/executors/crafters/test_segmenter.py | 4 ++-- tests/test_client.py | 2 +- tests/test_container.py | 18 +++++++++--------- tests/test_flow.py | 8 ++++---- tests/test_index.py | 10 +++++----- tests/test_index_remote.py | 4 ++-- tests/test_loadbalance.py | 4 ++-- tests/test_quant.py | 4 ++-- 12 files changed, 35 insertions(+), 41 deletions(-) diff --git a/jina/clients/python/__init__.py b/jina/clients/python/__init__.py index ef06508fafd7b..ab0148e924e88 100644 --- a/jina/clients/python/__init__.py +++ b/jina/clients/python/__init__.py @@ -6,7 +6,7 @@ from . import request from .grpc import GrpcClient from .helper import ProgressBar -from ...enums import ClientInputType, ClientMode +from ...enums import ClientMode from ...excepts import BadClient from ...logging import default_logger from ...logging.profile import TimeContext @@ -64,14 +64,13 @@ def mode(self, value: ClientMode): raise ValueError(f'{value} must be one of {ClientMode}') @staticmethod - def check_input(input_fn: Union[Iterator['jina_pb2.Document'], Iterator[bytes], Callable] = None, - input_type: ClientInputType = ClientInputType.BUFFER): + def check_input(input_fn: Union[Iterator['jina_pb2.Document'], Iterator[bytes], Callable] = None): """Validate the input_fn and print the first request if success :param input_fn: the input function :param input_type: if the input data is in protobuf Document format, or in raw bytes, or data uri """ - kwargs = {'data': input_fn, 'input_type': input_type} + kwargs = {'data': input_fn} try: r = next(getattr(request, 'index')(**kwargs)) diff --git a/jina/main/parser.py b/jina/main/parser.py index 7f3d2e236e060..fc5d4cd0cf666 100644 --- a/jina/main/parser.py +++ b/jina/main/parser.py @@ -400,7 +400,7 @@ def set_client_cli_parser(parser=None): if not parser: parser = set_base_parser() - from ..enums import ClientInputType, ClientMode + from ..enums import ClientMode _set_grpc_parser(parser) @@ -415,9 +415,6 @@ def set_client_cli_parser(parser=None): gp1.add_argument('--top-k', type=int, default=10, help='top_k results returned in the search mode') - gp1.add_argument('--input-type', choices=list(ClientInputType), default=ClientInputType.BUFFER, - type=ClientInputType.from_string, - help='the type of input data') gp1.add_argument('--mime-type', type=str, help='MIME type of the input, useful when input-type is set to BUFFER') gp1.add_argument('--callback-on-body', action='store_true', default=False, diff --git a/jina/peapods/gateway.py b/jina/peapods/gateway.py index d347d19bcec54..5db0f4ffc9494 100644 --- a/jina/peapods/gateway.py +++ b/jina/peapods/gateway.py @@ -12,7 +12,7 @@ from .pea import BasePea from .zmq import AsyncZmqlet, add_envelope from .. import __stop_msg__ -from ..enums import ClientInputType, ClientMode +from ..enums import ClientMode from ..excepts import NoExplicitMessage, RequestLoopEnd, NoDriverForRequest, BadRequestType from ..executors import BaseExecutor from ..logging.base import get_logger @@ -240,7 +240,6 @@ def api(mode): return http_error('"data" field is empty', 406) content['mode'] = ClientMode.from_string(mode) - content['input_type'] = ClientInputType.from_string(content.get('input_type', 'data_uri')) results = get_result_in_json(getattr(python.request, mode)(**content)) return Response(asyncio.run(results), diff --git a/tests/executors/crafters/test_mime.py b/tests/executors/crafters/test_mime.py index d711db58267f6..120940ce50fea 100644 --- a/tests/executors/crafters/test_mime.py +++ b/tests/executors/crafters/test_mime.py @@ -1,6 +1,5 @@ import glob -from jina.enums import ClientInputType from jina.flow import Flow from tests import JinaTestCase @@ -41,7 +40,7 @@ def test_dummy_seg(self): def test_any_file(self): f = Flow().add(yaml_path='!FilePath2DataURI\nwith: {base64: true}') with f: - f.index(input_fn=input_fn2, output_fn=print, input_type=ClientInputType.FILE_PATH) + f.index(input_fn=input_fn2, output_fn=print) def test_aba(self): f = (Flow().add(yaml_path='!Buffer2DataURI\nwith: {mimetype: png}') @@ -56,9 +55,9 @@ def test_any2buffer(self): .add(yaml_path='Buffer2DataURI')) with f: - f.index(input_fn=input_fn3, output_fn=print, input_type=ClientInputType.DATA_URI) + f.index(input_fn=input_fn3, output_fn=print) # def test_dummy_seg_random(self): # f = Flow().add(yaml_path='../../yaml/dummy-seg-random.yml') # with f: - # f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF, output_fn=self.collect_chunk_id) + # f.index(input_fn=random_docs(10), output_fn=self.collect_chunk_id) diff --git a/tests/executors/crafters/test_segmenter.py b/tests/executors/crafters/test_segmenter.py index 7738ed319a3ee..fedc4713293d2 100644 --- a/tests/executors/crafters/test_segmenter.py +++ b/tests/executors/crafters/test_segmenter.py @@ -30,9 +30,9 @@ def collect_chunk_id(self, req): def test_dummy_seg(self): f = Flow().add(yaml_path='DummySegment') with f: - f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF, output_fn=self.get_chunk_id) + f.index(input_fn=random_docs(10), output_fn=self.get_chunk_id) def test_dummy_seg_random(self): f = Flow().add(yaml_path='../../yaml/dummy-seg-random.yml') with f: - f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF, output_fn=self.collect_chunk_id) + f.index(input_fn=random_docs(10), output_fn=self.collect_chunk_id) diff --git a/tests/test_client.py b/tests/test_client.py index 1b962a08debbc..4594f41e4d3fc 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -27,7 +27,7 @@ def test_check_input(self): input_fn = iter([b'1234', b'45467']) PyClient.check_input(input_fn) input_fn = iter([Document(), Document()]) - PyClient.check_input(input_fn, input_type=ClientInputType.PROTOBUF) + PyClient.check_input(input_fn) bad_input_fn = iter([b'1234', '45467', [12, 2, 3]]) self.assertRaises(TypeError, PyClient.check_input, bad_input_fn) bad_input_fn = iter([Document(), None]) diff --git a/tests/test_container.py b/tests/test_container.py index 479f16ba553e0..75fc506901057 100644 --- a/tests/test_container.py +++ b/tests/test_container.py @@ -75,14 +75,14 @@ def test_flow_with_one_container_pod(self): .add(name='dummyEncoder', image=img_name)) with f: - f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF) + f.index(input_fn=random_docs(10)) def test_flow_with_one_container_ext_yaml(self): f = (Flow() .add(name='dummyEncoder', image=img_name, yaml_path='./mwu-encoder/mwu_encoder_ext.yml')) with f: - f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF) + f.index(input_fn=random_docs(10)) def test_flow_with_replica_container_ext_yaml(self): f = (Flow() @@ -92,9 +92,9 @@ def test_flow_with_replica_container_ext_yaml(self): replicas=3)) with f: - f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF) - f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF) - f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF) + f.index(input_fn=random_docs(10)) + f.index(input_fn=random_docs(10)) + f.index(input_fn=random_docs(10)) def test_flow_topo1(self): f = (Flow() @@ -105,7 +105,7 @@ def test_flow_topo1(self): .join(['d3', 'd2'])) with f: - f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF) + f.index(input_fn=random_docs(10)) def test_flow_topo_mixed(self): f = (Flow() @@ -117,7 +117,7 @@ def test_flow_topo_mixed(self): ) with f: - f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF) + f.index(input_fn=random_docs(10)) def test_flow_topo_replicas(self): f = (Flow() @@ -130,14 +130,14 @@ def test_flow_topo_replicas(self): with f: f.dry_run() - f.index(input_fn=random_docs(1000), input_type=ClientInputType.PROTOBUF) + f.index(input_fn=random_docs(1000)) def test_container_volume(self): f = (Flow() .add(name='dummyEncoder', image=img_name, volumes='./abc', yaml_path='mwu-encoder/mwu_encoder_upd.yml')) with f: - f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF) + f.index(input_fn=random_docs(10)) out_file = './abc/ext-mwu-encoder.bin' self.assertTrue(os.path.exists(out_file)) diff --git a/tests/test_flow.py b/tests/test_flow.py index 1c9f111fd3da3..5a0c4743a70b1 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -142,7 +142,7 @@ def test_flow_no_container(self): .add(name='dummyEncoder', yaml_path='mwu-encoder/mwu_encoder.yml')) with f: - f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF) + f.index(input_fn=random_docs(10)) def test_flow_yaml_dump(self): f = Flow(logserver_config='yaml/test-server-config.yml', @@ -165,7 +165,7 @@ def test_flow_log_server(self): def test_shards(self): f = Flow().add(name='doc_pb', yaml_path='yaml/test-docpb.yml', replicas=3, separated_workspace=True) with f: - f.index(input_fn=random_docs(1000), input_type=ClientInputType.PROTOBUF, random_doc_id=False) + f.index(input_fn=random_docs(1000), random_doc_id=False) with f: pass self.add_tmpfile('test-docshard') @@ -185,13 +185,13 @@ def validate(req): f = Flow().add(name='doc_pb', yaml_path='yaml/test-docpb.yml', replicas=replicas, separated_workspace=True) with f: - f.index(input_fn=random_docs(index_docs), input_type=ClientInputType.PROTOBUF, random_doc_id=False) + f.index(input_fn=random_docs(index_docs), random_doc_id=False) with f: pass f = Flow().add(name='doc_pb', yaml_path='yaml/test-docpb.yml', replicas=replicas, separated_workspace=True, polling='all', reducing_yaml_path='_merge_topk_docs') with f: - f.search(input_fn=random_queries(1, index_docs), input_type=ClientInputType.PROTOBUF, random_doc_id=False, output_fn=validate, + f.search(input_fn=random_queries(1, index_docs), random_doc_id=False, output_fn=validate, callback_on_body=True) self.add_tmpfile('test-docshard') diff --git a/tests/test_index.py b/tests/test_index.py index 11f5e71ca6771..0d7c184750e07 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -88,7 +88,7 @@ def test_doc_iters(self): def test_simple_route(self): f = Flow().add(yaml_path='_forward') with f: - f.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF) + f.index(input_fn=random_docs(10)) def test_update_method(self): a = DummyIndexer(index_filename='test.bin') @@ -120,7 +120,7 @@ def test_two_client_route_replicas(self): # f3 = Flow(optimize_level=FlowOptimizeLevel.FULL).add(yaml_path='_forward', replicas=3) def start_client(fl): - fl.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF) + fl.index(input_fn=random_docs(10)) with f1: self.assertEqual(f1.num_peas, 6) @@ -152,7 +152,7 @@ def test_two_client_route(self): f = Flow().add(yaml_path='_forward') def start_client(fl): - fl.index(input_fn=random_docs(10), input_type=ClientInputType.PROTOBUF) + fl.index(input_fn=random_docs(10)) with f: t1 = mp.Process(target=start_client, args=(f,)) @@ -167,7 +167,7 @@ def start_client(fl): def test_index(self): f = Flow().add(yaml_path='yaml/test-index.yml', replicas=3, separated_workspace=True) with f: - f.index(input_fn=random_docs(1000), input_type=ClientInputType.PROTOBUF) + f.index(input_fn=random_docs(1000)) for j in range(3): self.assertTrue(os.path.exists(f'test2-{j + 1}/test2.bin')) @@ -176,7 +176,7 @@ def test_index(self): time.sleep(3) with f: - f.search(input_fn=random_docs(1), input_type=ClientInputType.PROTOBUF, output_fn=get_result, top_k=100) + f.search(input_fn=random_docs(1), output_fn=get_result, top_k=100) if __name__ == '__main__': diff --git a/tests/test_index_remote.py b/tests/test_index_remote.py index 0b994341a25a8..4ee0501ae2ab9 100644 --- a/tests/test_index_remote.py +++ b/tests/test_index_remote.py @@ -98,7 +98,7 @@ def start_gateway(): host='localhost', port_grpc=f_args.port_grpc) with f: - f.index(input_fn=random_docs(1000), input_type=ClientInputType.PROTOBUF) + f.index(input_fn=random_docs(1000)) time.sleep(3) for j in range(3): @@ -123,7 +123,7 @@ def start_gateway(): host='192.168.31.76', port_grpc=44444)) with f: - f.index(input_fn=random_docs(1000), input_type=ClientInputType.PROTOBUF) + f.index(input_fn=random_docs(1000)) if __name__ == '__main__': diff --git a/tests/test_loadbalance.py b/tests/test_loadbalance.py index 2288611d8b339..c6816f3c9bff9 100644 --- a/tests/test_loadbalance.py +++ b/tests/test_loadbalance.py @@ -46,7 +46,7 @@ def test_lb(self): yaml_path='SlowWorker', replicas=10) with f: - f.index(input_fn=random_docs(100), input_type=ClientInputType.PROTOBUF, batch_size=10) + f.index(input_fn=random_docs(100), batch_size=10) def test_roundrobin(self): f = Flow(runtime='process').add( @@ -54,4 +54,4 @@ def test_roundrobin(self): yaml_path='SlowWorker', replicas=10, scheduling=SchedulerType.ROUND_ROBIN) with f: - f.index(input_fn=random_docs(100), input_type=ClientInputType.PROTOBUF, batch_size=10) + f.index(input_fn=random_docs(100), batch_size=10) diff --git a/tests/test_quant.py b/tests/test_quant.py index 73db3147cdd41..b6f0c50b7eb59 100644 --- a/tests/test_quant.py +++ b/tests/test_quant.py @@ -52,7 +52,7 @@ def f1(self, quant): yaml_path='_forward').add( yaml_path='_forward').add(yaml_path='_forward').add(yaml_path='_forward').add(yaml_path='_forward') with f as fl: - fl.index(random_docs, output_fn=get_output, input_type=ClientInputType.PROTOBUF) + fl.index(random_docs, output_fn=get_output) def f2(self, quant): os.environ['JINA_ARRAY_QUANT'] = quant @@ -61,7 +61,7 @@ def f2(self, quant): yaml_path='_forward').add( yaml_path='_forward').add(yaml_path='_forward').add(yaml_path='_forward').add(yaml_path='_forward') with f as fl: - fl.index(random_docs, output_fn=get_output, input_type=ClientInputType.PROTOBUF) + fl.index(random_docs, output_fn=get_output) def test_quant(self): for j in ('fp32', 'fp16', 'uint8'):