From d834f578218ee38d4984ec8243b2d77aa6bb65ba Mon Sep 17 00:00:00 2001 From: Larry Yan Date: Fri, 26 Jul 2019 15:33:21 +0800 Subject: [PATCH] fix(service): add doc type to req generator --- gnes/proto/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gnes/proto/__init__.py b/gnes/proto/__init__.py index f2bc400e..5da5917c 100644 --- a/gnes/proto/__init__.py +++ b/gnes/proto/__init__.py @@ -28,7 +28,7 @@ class RequestGenerator: @staticmethod - def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: str = 'TEXT', *args, **kwargs): + def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: int = 1, *args, **kwargs): for pi in batch_iterator(data, batch_size): req = gnes_pb2.Request() @@ -42,7 +42,7 @@ def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: s start_id += 1 @staticmethod - def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: str = 'TEXT', *args, **kwargs): + def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: int = 1, *args, **kwargs): for pi in batch_iterator(data, batch_size): req = gnes_pb2.Request() req.request_id = str(start_id) @@ -59,7 +59,7 @@ def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: s start_id += 1 @staticmethod - def query(query: bytes, top_k: int, start_id: int = 0, doc_type: str = 'TEXT', *args, **kwargs): + def query(query: bytes, top_k: int, start_id: int = 0, doc_type: int = 1, *args, **kwargs): if top_k <= 0: raise ValueError('"top_k: %d" is not a valid number' % top_k)