From 1707fc30806c61a70f73c991ea962e47308dddcb Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 19 Feb 2025 16:48:57 +0800 Subject: [PATCH 1/2] fix --- swift/llm/infer/infer_engine/infer_engine.py | 3 +++ swift/llm/sampling/utils.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/swift/llm/infer/infer_engine/infer_engine.py b/swift/llm/infer/infer_engine/infer_engine.py index fc2fa6ab2..3e8031e14 100644 --- a/swift/llm/infer/infer_engine/infer_engine.py +++ b/swift/llm/infer/infer_engine/infer_engine.py @@ -122,6 +122,9 @@ def _update_metrics(result, metrics: Optional[List[Metric]] = None): metric.update(response) return result_origin + def __call__(self, *args, **kwargs): + return self.infer(*args, **kwargs) + def infer(self, infer_requests: List[InferRequest], request_config: Optional[RequestConfig] = None, diff --git a/swift/llm/sampling/utils.py b/swift/llm/sampling/utils.py index 97bdfe504..3d48c6a53 100644 --- a/swift/llm/sampling/utils.py +++ b/swift/llm/sampling/utils.py @@ -43,6 +43,9 @@ def get_reward(model: Any, if 'ground_truths' in parameters: gt_param = {'ground_truths': ground_truths} rewards = model(infer_requests, request_config=request_config, **gt_param) + from swift.llm.infer.protocol import ChatCompletionResponse + if isinstance(rewards[0], ChatCompletionResponse): + rewards = [float(r.choices[0].message.content) for r in rewards] arr = [] for reward in rewards: if isinstance(reward, (list, tuple)): From b37f90ffc90dfb1cea519c3fa543054dc1b0ddbf Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 19 Feb 2025 16:49:28 +0800 Subject: [PATCH 2/2] lint --- swift/llm/infer/infer_engine/infer_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/llm/infer/infer_engine/infer_engine.py b/swift/llm/infer/infer_engine/infer_engine.py index 3e8031e14..2934e2dda 100644 --- a/swift/llm/infer/infer_engine/infer_engine.py +++ b/swift/llm/infer/infer_engine/infer_engine.py @@ -124,7 +124,7 @@ def _update_metrics(result, metrics: Optional[List[Metric]] = None): def __call__(self, *args, **kwargs): return self.infer(*args, **kwargs) - + def infer(self, infer_requests: List[InferRequest], request_config: Optional[RequestConfig] = None,