Commit

update
Jintao-Huang committed Feb 21, 2025
1 parent c40ba0c commit bb277ed
Showing 5 changed files with 5 additions and 3 deletions.
3 changes: 1 addition & 2 deletions swift/llm/infer/infer_engine/infer_engine.py
@@ -33,7 +33,6 @@ def _post_init(self):
         if getattr(self, 'default_template', None) is None:
             self.default_template = get_template(self.model_meta.template, self.processor)
         self._adapters_pool = {}
-        self.strict = True

     def _get_stop_words(self, stop_words: List[Union[str, List[int], None]]) -> List[str]:
         stop: List[str] = []
@@ -83,7 +82,7 @@ async def _batch_run(tasks):
             if output is None or isinstance(output, Exception):
                 # is_finished
                 if isinstance(output, Exception):
-                    if self.strict:
+                    if getattr(self, 'strict', True):
                         raise
                     outputs[i] = output
                 n_finished += 1
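Note on the change above: removing the unconditional self.strict = True from _post_init works because getattr(self, 'strict', True) falls back to True whenever the attribute was never set, so an untouched engine still raises on the first error, while callers can now opt out by assigning the attribute. A minimal sketch of the fallback (standalone Python with a hypothetical Engine class, not the real InferEngine):

    class Engine:
        # Hypothetical stand-in; only illustrates the getattr fallback used above.
        def handle(self, output):
            if isinstance(output, Exception):
                # Defaults to strict (re-raise) when 'strict' was never set on the instance.
                if getattr(self, 'strict', True):
                    raise output
                # Non-strict: keep the exception as the result instead of aborting the batch.
            return output

    engine = Engine()
    engine.strict = False                     # opt out, as the samplers below do
    print(engine.handle(ValueError('boom')))  # prints the exception instead of raising

    # Engine().handle(ValueError('boom'))     # attribute never set -> would re-raise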
2 changes: 1 addition & 1 deletion swift/llm/infer/infer_engine/pt_engine.py
@@ -418,7 +418,7 @@ def _infer(
             try:
                 batched_inputs.append(future.result())
             except Exception as e:
-                if self.strict:
+                if getattr(self, 'strict', True):
                     raise
                 error_list.append((i, e))
                 continue
1 change: 1 addition & 0 deletions swift/llm/sampling/distill_sampler.py
@@ -73,6 +73,7 @@ def __init__(self, *args, **kwargs):
         assert self.args.sampler_engine == 'client'
         _Engine = OpenAI_Engine
         self.infer_engine = _Engine(model=self.args.model, stream=self.args.stream, **self.args.engine_kwargs)
+        self.infer_engine.strict = False
         self.caches = self.read_cache()

     def _prepare_model_tokenizer(self):
1 change: 1 addition & 0 deletions swift/llm/sampling/vanilla_sampler.py
@@ -37,6 +37,7 @@ def __init__(self, *args, **kwargs):
         if _Engine:
             self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs)
             self.infer_engine.default_template = self.template
+            self.infer_engine.strict = False
         self.caches = self.read_cache()

     def read_cache(self):
1 change: 1 addition & 0 deletions swift/plugin/prm.py
@@ -102,6 +102,7 @@ def __init__(self, api_key=None, base_url=None, model=None):
         if model is None:
             model = 'qwen-plus'
         self.infer_engine = InferClient(base_url=base_url, api_key=api_key)
+        self.infer_engine.strict = False
         self.infer_kwargs = {
             'model': model,
         }
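All three call sites above (distill_sampler.py, vanilla_sampler.py, prm.py) flip the new switch right after constructing their engine, so a single failed request is recorded rather than aborting the whole sampling run. A hedged usage sketch of the same pattern (the base_url/api_key values and the import path are placeholders, not taken from this commit):

    from swift.llm import InferClient   # import path assumed, not shown in this diff

    infer_engine = InferClient(base_url='http://localhost:8000/v1', api_key='EMPTY')
    infer_engine.strict = False  # per-request failures come back as results, not raised

    # With strict disabled, batched inference can return Exception objects in place of
    # failed responses, so the caller filters them instead of wrapping every call in try/except:
    # responses = infer_engine.infer(infer_requests, request_config)  # signature assumed
    # results = [r for r in responses if not isinstance(r, Exception)]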
