add docstrings, refactor logger
Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
@@ -35,6 +35,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
from trl import PPOConfig, PPOTrainer
|
||||
from trl.core import PPODecorators, logprobs_from_logits
|
||||
from trl.models.utils import unwrap_model_for_generation
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
|
||||
@@ -298,6 +299,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||
|
||||
@override
|
||||
def create_optimizer(
|
||||
self,
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
@@ -324,6 +326,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
return optimizer
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
|
||||
) -> "torch.optim.lr_scheduler.LRScheduler":
|
||||
@@ -410,6 +413,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
|
||||
return rewards.float().detach() # use fp32 type
|
||||
|
||||
@override
|
||||
@PPODecorators.empty_device_cache()
|
||||
def batched_forward_pass(
|
||||
self,
|
||||
@@ -478,6 +482,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
torch.cat(all_masks)[:, :-1],
|
||||
)
|
||||
|
||||
@override
|
||||
def save_model(self, output_dir: Optional[str] = None) -> None:
|
||||
r"""
|
||||
Saves model checkpoint.
|
||||
|
||||
Reference in New Issue
Block a user