Former-commit-id: 627d1c91e675f1d9ebf47bad123cbbf29821da4d
This commit is contained in:
@@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits
|
||||
|
||||
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
|
||||
from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self.model_args = model_args
|
||||
self.finetuning_args = finetuning_args
|
||||
self.reward_model = reward_model
|
||||
self.current_device = get_current_device() # patch for deepspeed training
|
||||
|
||||
self.generation_config = GenerationConfig(
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
|
||||
Reference in New Issue
Block a user