Former-commit-id: 627d1c91e675f1d9ebf47bad123cbbf29821da4d
This commit is contained in:
hiyouga
2024-03-09 02:01:26 +08:00
parent 2f095e2017
commit 43b2ede0f8
7 changed files with 28 additions and 20 deletions

View File

@@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -49,6 +49,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model_args = model_args
self.finetuning_args = finetuning_args
self.reward_model = reward_model
self.current_device = get_current_device() # patch for deepspeed training
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,