fix ppo trainer #551
Former-commit-id: 050a5447c191b8c50a0826a0f03bae499bff8b48
This commit is contained in:
@@ -12,7 +12,7 @@ from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
from llmtuner.tuner.ppo.utils import replace_model
|
||||
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
@@ -152,8 +152,10 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
if length_sampler is not None:
|
||||
generation_kwargs["max_new_tokens"] = length_sampler()
|
||||
|
||||
self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
|
||||
self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
|
||||
|
||||
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
|
||||
|
||||
Reference in New Issue
Block a user