10x generate in ppo w/ zero3
lightDev0405 committed May 28, 2024
1 parent 85d0185 · commit d0d39d5
Showing 1 changed file with 5 additions and 4 deletions.
src/llamafactory/train/ppo/trainer.py (5 additions, 4 deletions)

@@ -13,6 +13,7 @@
 from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from trl import PPOConfig, PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
+from trl.models.utils import unwrap_model_for_generation

 from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
 from ...extras.logging import get_logger
@@ -322,10 +323,10 @@ def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor]
             for k, v in batch.items():
                 batch[k] = v[:, start_index:]

-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        generate_output: torch.Tensor = unwrapped_model.generate(
-            generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
-        )
+        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+            generate_output: torch.Tensor = unwrapped_model.generate(
+                generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
+            )

         if self.model_args.upcast_layernorm:
             restore_layernorm(self.model, layernorm_params)
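The old path called generate on the bare model returned by accelerator.unwrap_model, so under DeepSpeed ZeRO-3 the sharded weights had to be re-gathered on every decoding step. trl's unwrap_model_for_generation is a context manager that, on ZeRO-3, gathers the full parameters once and keeps them resident for the whole generation loop, which is where the roughly 10x generation speedup in the commit title comes from. Below is a minimal usage sketch of the adopted pattern, not the repository's code: the helper name generate_responses is hypothetical, and it assumes an installed trl version that exports unwrap_model_for_generation and a model already prepared by accelerate.

# Minimal sketch of the pattern this commit adopts; generate_responses is a
# hypothetical helper, and the trl import is assumed to exist in the
# installed trl version.
import torch
from accelerate import Accelerator
from trl.models.utils import unwrap_model_for_generation


def generate_responses(model, accelerator: Accelerator, batch: dict, generation_config) -> torch.Tensor:
    # Under ZeRO-3, the context manager gathers the sharded weights once for
    # the whole decoding loop instead of re-gathering them on every forward
    # step; on non-ZeRO-3 setups it behaves like accelerator.unwrap_model.
    with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
        return unwrapped_model.generate(generation_config=generation_config, **batch)

The trade-off is memory: while the context is open, each rank holds an unsharded copy of the weights, so this helps only when the gathered model fits alongside the ZeRO-3 training state.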
