From 6e84bd785f15a23c562946d667f7244e2c21705b Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 19 Feb 2025 13:36:45 +0800 Subject: [PATCH 1/2] update --- swift/llm/infer/infer_engine/pt_engine.py | 2 ++ swift/llm/template/base.py | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py index 76a5b7765..36b682012 100644 --- a/swift/llm/infer/infer_engine/pt_engine.py +++ b/swift/llm/infer/infer_engine/pt_engine.py @@ -420,6 +420,8 @@ def _infer( except Exception as e: error_list.append((i, e)) continue + if len(batched_inputs) == 0 and len(error_list) > 0: + raise error_list[0][1] template_inputs = [inputs.pop('template_inputs') for inputs in batched_inputs] inputs = to_device(template.data_collator(batched_inputs), self.model.device) template.debug_logger(inputs) # debug diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 1b770f96a..cfde3adbd 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -673,10 +673,10 @@ def _swift_encode(self, inputs: StdTemplateInputs): res_context_types: List[ContextType] = [] sep_token = None if template_meta.auto_add_bos: - all_tokens = self.tokenizer.encode('0') - single_zero = self.tokenizer.encode('0', add_special_tokens=False) - assert len(single_zero) == 1 - idx = all_tokens.index(single_zero[0]) + all_tokens = self.tokenizer.encode('a') + single_token = self.tokenizer.encode('a', add_special_tokens=False) + assert len(single_token) == 1 + idx = all_tokens.index(single_token[0]) bos_token = all_tokens[:idx] sep_token = all_tokens[idx + 1:] if bos_token: From a6db4e7b7809d733febc71ecc516e9e3c470af11 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 19 Feb 2025 13:45:19 +0800 Subject: [PATCH 2/2] update --- swift/llm/infer/infer_engine/pt_engine.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py index 36b682012..477f220c0 100644 --- a/swift/llm/infer/infer_engine/pt_engine.py +++ b/swift/llm/infer/infer_engine/pt_engine.py @@ -420,8 +420,15 @@ def _infer( except Exception as e: error_list.append((i, e)) continue - if len(batched_inputs) == 0 and len(error_list) > 0: - raise error_list[0][1] + if len(batched_inputs) == 0: + if request_config.stream: + + def _gen_wrapper(): + yield self._add_error_list([], error_list) + + return _gen_wrapper() + else: + return self._add_error_list([], error_list) template_inputs = [inputs.pop('template_inputs') for inputs in batched_inputs] inputs = to_device(template.data_collator(batched_inputs), self.model.device) template.debug_logger(inputs) # debug