
Commit 70e19b4

Fix prm in sampler (#3184)
tastelikefeet authored Feb 19, 2025
1 parent 0d0aa03 commit 70e19b4
Showing 2 changed files with 6 additions and 0 deletions.
3 changes: 3 additions & 0 deletions swift/llm/infer/infer_engine/infer_engine.py
@@ -122,6 +122,9 @@ def _update_metrics(result, metrics: Optional[List[Metric]] = None):
                 metric.update(response)
         return result_origin
 
+    def __call__(self, *args, **kwargs):
+        return self.infer(*args, **kwargs)
+
     def infer(self,
               infer_requests: List[InferRequest],
               request_config: Optional[RequestConfig] = None,
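
This change makes an engine instance directly callable as an alias for infer, so code that treats the engine as a generic reward model, such as get_reward in the second file below, can invoke it as model(infer_requests, request_config=...). A minimal standalone sketch of the delegation pattern, using a hypothetical Engine class rather than swift's actual InferEngine:

from typing import Any, List, Optional

class Engine:
    # Hypothetical stand-in for an inference engine; the names here are
    # illustrative, not swift's actual API.
    def infer(self, requests: List[str], request_config: Optional[dict] = None) -> List[str]:
        # A real engine would run the model; here we just echo the request.
        return ['response to ' + r for r in requests]

    def __call__(self, *args: Any, **kwargs: Any) -> List[str]:
        # Same delegation as the commit: engine(...) now behaves
        # exactly like engine.infer(...).
        return self.infer(*args, **kwargs)

engine = Engine()
assert engine(['q']) == engine.infer(['q'])
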
3 changes: 3 additions & 0 deletions swift/llm/sampling/utils.py
@@ -43,6 +43,9 @@ def get_reward(model: Any,
     if 'ground_truths' in parameters:
         gt_param = {'ground_truths': ground_truths}
     rewards = model(infer_requests, request_config=request_config, **gt_param)
+    from swift.llm.infer.protocol import ChatCompletionResponse
+    if isinstance(rewards[0], ChatCompletionResponse):
+        rewards = [float(r.choices[0].message.content) for r in rewards]
     arr = []
     for reward in rewards:
         if isinstance(reward, (list, tuple)):
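
This hunk is the actual PRM fix: when the reward model is served through the now-callable engine, get_reward receives ChatCompletionResponse objects instead of raw scores, so the numeric reward has to be parsed from each response's first choice. A rough, self-contained sketch of that unwrapping, using simplified stand-in dataclasses rather than swift.llm.infer.protocol's real types (the choices[0].message.content nesting is taken from the diff above):

from dataclasses import dataclass
from typing import List, Union

# Simplified stand-ins for swift's protocol types, kept only as deep
# as the attribute path the commit relies on.
@dataclass
class Message:
    content: str

@dataclass
class Choice:
    message: Message

@dataclass
class ChatCompletionResponse:
    choices: List[Choice]

def unwrap_rewards(rewards: List[Union[float, ChatCompletionResponse]]) -> List[float]:
    # Mirrors the added branch: if the PRM answered with chat responses,
    # the score is the text of the first choice, parsed as a float.
    if rewards and isinstance(rewards[0], ChatCompletionResponse):
        return [float(r.choices[0].message.content) for r in rewards]
    return [float(r) for r in rewards]

resp = ChatCompletionResponse(choices=[Choice(message=Message(content='0.875'))])
assert unwrap_rewards([resp]) == [0.875]
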
