Merge branch 'main' into compat_vllm_0.7.2
Jintao-Huang committed Feb 19, 2025
2 parents 18ed3a3 + 0d0aa03 commit d2416d1
Showing 5 changed files with 34 additions and 58 deletions.
9 changes: 9 additions & 0 deletions swift/llm/infer/infer_engine/pt_engine.py
@@ -420,6 +420,15 @@ def _infer(
            except Exception as e:
                error_list.append((i, e))
                continue
+        if len(batched_inputs) == 0:
+            if request_config.stream:
+
+                def _gen_wrapper():
+                    yield self._add_error_list([], error_list)
+
+                return _gen_wrapper()
+            else:
+                return self._add_error_list([], error_list)
        template_inputs = [inputs.pop('template_inputs') for inputs in batched_inputs]
        inputs = to_device(template.data_collator(batched_inputs), self.model.device)
        template.debug_logger(inputs)  # debug
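
Reviewer note on pt_engine.py: when every request in a batch fails template encoding, _infer now returns early instead of calling the model with an empty batch, while still handing the collected errors back to the caller. A minimal, self-contained sketch of the same pattern under assumed semantics (merge_errors is a hypothetical stand-in for _add_error_list, assumed to splice each error back in at the index of the request that raised it):

from typing import Any, Dict, Iterator, List, Tuple, Union


def merge_errors(results: List[Any], errors: List[Tuple[int, Exception]]) -> List[Any]:
    # Hypothetical stand-in for _add_error_list: re-insert each error at the
    # index of the request that produced it.
    merged = list(results)
    for idx, exc in sorted(errors, key=lambda pair: pair[0]):
        merged.insert(idx, exc)
    return merged


def infer(batched_inputs: List[Dict[str, Any]], errors: List[Tuple[int, Exception]],
          stream: bool) -> Union[List[Any], Iterator[List[Any]]]:
    if len(batched_inputs) == 0:
        if stream:
            # Streaming callers iterate over the result, so the error-only
            # payload is wrapped in a one-shot generator instead of a list.
            def _gen_wrapper():
                yield merge_errors([], errors)

            return _gen_wrapper()
        return merge_errors([], errors)
    raise NotImplementedError('normal inference path omitted from this sketch')


print(infer([], [(0, ValueError('encode failed'))], stream=False))
print(next(infer([], [(0, ValueError('encode failed'))], stream=True)))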
2 changes: 1 addition & 1 deletion swift/llm/sampling/vanilla_sampler.py
@@ -133,7 +133,7 @@ def generate(self, data):
            resps = row
            resps['choices'] = []
            for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
-                if resp_list[j].choices[0].message.content:
+                if not isinstance(resp_list[j], Exception):
                    resps['choices'].append(resp_list[j].choices[0].message.content)
            if resps['choices']:
                resp_all.append(resps)
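
Reviewer note on vanilla_sampler.py: since failed generations can now come back as Exception objects in place of response objects (matching the pt_engine.py change above), evaluating resp_list[j].choices[0].message.content would raise AttributeError on a failed entry; the isinstance check skips those entries instead. A small self-contained illustration with made-up stand-in classes (_Resp, _Choice and _Msg are not the real response types):

from dataclasses import dataclass
from typing import List, Union


@dataclass
class _Msg:
    content: str


@dataclass
class _Choice:
    message: _Msg


@dataclass
class _Resp:
    choices: List[_Choice]


# A batch where one generation failed and was replaced by an Exception.
resp_list: List[Union[_Resp, Exception]] = [
    _Resp([_Choice(_Msg('answer A'))]),
    RuntimeError('generation failed'),
]
choices = [r.choices[0].message.content for r in resp_list if not isinstance(r, Exception)]
print(choices)  # ['answer A']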
8 changes: 4 additions & 4 deletions swift/llm/template/base.py
@@ -673,10 +673,10 @@ def _swift_encode(self, inputs: StdTemplateInputs):
        res_context_types: List[ContextType] = []
        sep_token = None
        if template_meta.auto_add_bos:
-            all_tokens = self.tokenizer.encode('0')
-            single_zero = self.tokenizer.encode('0', add_special_tokens=False)
-            assert len(single_zero) == 1
-            idx = all_tokens.index(single_zero[0])
+            all_tokens = self.tokenizer.encode('a')
+            single_token = self.tokenizer.encode('a', add_special_tokens=False)
+            assert len(single_token) == 1
+            idx = all_tokens.index(single_token[0])
            bos_token = all_tokens[:idx]
            sep_token = all_tokens[idx + 1:]
            if bos_token:
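
Reviewer note on base.py: the probe encodes one character with and without special tokens to discover which token ids the tokenizer prepends (bos_token) and appends (sep_token) around plain text; switching the probe character from '0' to 'a' presumably sidesteps tokenizers for which '0' does not encode to exactly one token. A sketch of the same probe against a Hugging Face tokenizer (the checkpoint name is only an example and assumes the transformers package is available):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-7B-Instruct')
all_tokens = tokenizer.encode('a')                               # ids including any auto-added specials
single_token = tokenizer.encode('a', add_special_tokens=False)   # ids for the bare character
assert len(single_token) == 1  # the probe must map to exactly one token
idx = all_tokens.index(single_token[0])
bos_token = all_tokens[:idx]      # everything the tokenizer prepends automatically
sep_token = all_tokens[idx + 1:]  # everything it appends automatically
print(bos_token, sep_token)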
60 changes: 13 additions & 47 deletions swift/plugin/orm.py
@@ -295,80 +295,46 @@ def __call__(self, completions, **kwargs) -> List[float]:

class CosineReward(ORM):
    # https://arxiv.org/abs/2502.03373
-    def __init__(
-        self,
-        cosine_min_len_value_wrong: float = 0.0,
-        cosine_max_len_value_wrong: float = -0.5,
-        cosine_min_len_value_correct: float = 1.0,
-        cosine_max_len_value_correct: float = 0.5,
-        cosine_max_len: int = 1000,
-    ):
-        super().__init__()
-        import importlib.util
-        assert importlib.util.find_spec('math_verify') is not None, (
-            "The math_verify package is required but not installed. Please install it using 'pip install math_verify'.")
+    def __init__(self,
+                 cosine_min_len_value_wrong: float = 0.0,
+                 cosine_max_len_value_wrong: float = -0.5,
+                 cosine_min_len_value_correct: float = 1.0,
+                 cosine_max_len_value_correct: float = 0.5,
+                 cosine_max_len: int = 1000,
+                 accuracy_orm=None):
        self.min_len_value_wrong = cosine_min_len_value_wrong
        self.max_len_value_wrong = cosine_max_len_value_wrong
        self.min_len_value_correct = cosine_min_len_value_correct
        self.max_len_value_correct = cosine_max_len_value_correct
        self.max_len = cosine_max_len
+        self.accuracy_orm = accuracy_orm or MathAccuracy()

    @staticmethod
    def cosfn(t, T, min_value, max_value):
        import math
        return max_value - (max_value - min_value) * (1 - math.cos(t * math.pi / T)) / 2

    def __call__(self, completions, solution, **kwargs) -> List[float]:
-        from latex2sympy2_extended import NormalizationConfig
-        from math_verify import LatexExtractionConfig, parse, verify
+        acc_rewards = self.accuracy_orm(completions, solution, **kwargs)
        rewards = []

-        for content, sol in zip(completions, solution):
-            gold_parsed = parse(sol, extraction_mode='first_match', extraction_config=[LatexExtractionConfig()])
-            if len(gold_parsed) == 0:
-                rewards.append(1.0)  # Skip unparseable examples
-                print('Failed to parse gold solution: ', sol)
-                continue
-
-            answer_parsed = parse(
-                content,
-                extraction_config=[
-                    LatexExtractionConfig(
-                        normalization_config=NormalizationConfig(
-                            nits=False,
-                            malformed_operators=False,
-                            basic_latex=True,
-                            equations=True,
-                            boxed=True,
-                            units=True,
-                        ),
-                        boxed_match_priority=0,
-                        try_extract_without_anchor=False,
-                    )
-                ],
-                extraction_mode='first_match',
-            )
-
-            is_correct = verify(answer_parsed, gold_parsed)
-            gen_len = len(content)
-
+        for content, acc_reward in zip(completions, acc_rewards):
+            is_correct = acc_reward >= 1.
            if is_correct:
                # Swap min/max for correct answers
                min_value = self.max_len_value_correct
                max_value = self.min_len_value_correct
            else:
                min_value = self.min_len_value_wrong
                max_value = self.max_len_value_wrong

+            gen_len = len(content)
            reward = self.cosfn(gen_len, self.max_len, min_value, max_value)
-            rewards.append(float(reward))
+            rewards.append(reward)
        return rewards


class RepetitionPenalty(ORM):
    # https://arxiv.org/abs/2502.03373
    def __init__(self, repetition_n_grams: int = 3, repetition_max_penalty: float = -1.0):
-        super().__init__()
        self.ngram_size = repetition_n_grams
        self.max_penalty = repetition_max_penalty

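
Reviewer note on orm.py: CosineReward no longer parses and verifies LaTeX itself; correctness now comes from a pluggable accuracy ORM (MathAccuracy by default), and only the length-dependent cosine schedule stays in this class. The schedule interpolates from max_value at length 0 to min_value at length cosine_max_len, so with the default parameters a correct answer scores highest when it is short and a wrong answer is penalised hardest when it is short. A quick standalone check of the formula shown above:

import math


def cosfn(t, T, min_value, max_value):
    # Half-cosine interpolation from max_value at t = 0 to min_value at t = T.
    return max_value - (max_value - min_value) * (1 - math.cos(t * math.pi / T)) / 2


T = 1000  # cosine_max_len
# Correct answers (min 0.5, max 1.0 after the swap): shorter completions score higher.
print([round(cosfn(t, T, 0.5, 1.0), 2) for t in (0, 500, 1000)])   # [1.0, 0.75, 0.5]
# Wrong answers (min 0.0, max -0.5): short wrong answers are penalised hardest.
print([round(cosfn(t, T, 0.0, -0.5), 2) for t in (0, 500, 1000)])  # [-0.5, -0.25, 0.0]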
13 changes: 7 additions & 6 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -15,7 +15,7 @@
from trl import GRPOTrainer as HFGRPOTrainer
from trl.models import unwrap_model_for_generation

-from swift.llm import InferRequest, RequestConfig, to_device
+from swift.llm import InferRequest, RequestConfig, RowPreprocessor, to_device
from swift.plugin import orms
from swift.utils import (JsonlWriter, get_device, get_device_count, get_dist_setting, get_logger, is_lmdeploy_available,
                         is_vllm_available, is_wandb_available)
@@ -51,10 +51,11 @@ def __init__(self,
                if reward_func in orms:
                    reward_func_class = orms[reward_func]
                    reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters)
-                    reward_func_args = [
-                        getattr(args, param) for param in reward_func_args if param not in ['self', 'args', 'kwargs']
-                    ]
-                    reward_funcs[i] = reward_func_class(*reward_func_args)
+                    reward_func_kwargs = {
+                        key: getattr(args, key)
+                        for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key)
+                    }
+                    reward_funcs[i] = reward_func_class(**reward_func_kwargs)
                elif not callable(reward_func):
                    raise ValueError(f'reward_function {reward_func} is not implemented in swift.llm.plugin')
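
Reviewer note on the constructor change above: building keyword arguments and filtering with hasattr lets a reward-class parameter that has no matching training argument (for example the new accuracy_orm parameter of CosineReward) fall back to its default, whereas the old positional list would have shifted or failed. A toy illustration with stand-in classes (Args and this trimmed CosineReward are not the real ms-swift classes):

import inspect


class CosineReward:
    # Trimmed stand-in: one parameter has no matching attribute on the args object.
    def __init__(self, cosine_max_len: int = 1000, accuracy_orm=None):
        self.cosine_max_len = cosine_max_len
        self.accuracy_orm = accuracy_orm


class Args:
    cosine_max_len = 2000  # note: no `accuracy_orm` attribute


args = Args()
reward_func_args = list(inspect.signature(CosineReward.__init__).parameters)
reward_func_kwargs = {
    key: getattr(args, key)
    for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key)
}
reward = CosineReward(**reward_func_kwargs)  # accuracy_orm simply falls back to its default
print(reward_func_kwargs)  # {'cosine_max_len': 2000}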

@@ -326,7 +327,7 @@ def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:
                rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]
            else:
                # Repeat all input columns (but "messages" and "completion") to match the number of generations
-                reward_kwargs = {key: [example[key] for example in inputs] for key in inputs[0]}
+                reward_kwargs = RowPreprocessor.rows_to_batched(inputs)
                output_reward_func = reward_func(completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

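Reviewer note on the _prepare_inputs change: RowPreprocessor.rows_to_batched presumably performs the same row-to-column transposition as the comprehension it replaces, turning a list of example dicts into one dict of per-key lists that is splatted into the reward function:

# Equivalent of the removed comprehension, shown on toy data.
inputs = [{'prompt': 'a', 'label': 1}, {'prompt': 'b', 'label': 2}]
reward_kwargs = {key: [example[key] for example in inputs] for key in inputs[0]}
print(reward_kwargs)  # {'prompt': ['a', 'b'], 'label': [1, 2]}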
