commit 3336422760
parent 04423b916f
Author: hiyouga
Date: 2024-03-26 17:26:14 +08:00

Former-commit-id: 616917bb3be7f71073b56ad8c7bc4e164b08b9b5

7 changed files with 36 additions and 31 deletions

View File

@@ -20,12 +20,9 @@ if TYPE_CHECKING:
 class CustomDPOTrainer(DPOTrainer):
     def __init__(
         self,
-        beta: float,
-        loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
-        ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
-        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
+        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
+        finetuning_args: "FinetuningArguments",
         disable_dropout: bool = True,
         **kwargs,
     ):
@@ -47,10 +44,10 @@ class CustomDPOTrainer(DPOTrainer):
         self._peft_has_been_casted_to_bf16 = False
         self.ref_model = ref_model
-        self.beta = beta
-        self.label_smoothing = 0
-        self.loss_type = loss_type
-        self.ftx_gamma = ftx_gamma
+        self.beta = finetuning_args.dpo_beta
+        self.label_smoothing = finetuning_args.dpo_label_smoothing
+        self.loss_type = finetuning_args.dpo_loss
+        self.ftx_gamma = finetuning_args.dpo_ftx
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
         Trainer.__init__(self, model=model, **kwargs)
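
For context, the four values the constructor now reads all live on the FinetuningArguments object. A minimal sketch of those fields in Python — the attribute names come from this diff and the Literal choices from the old signature, while the defaults are illustrative assumptions, not values taken from this commit:

    from dataclasses import dataclass
    from typing import Literal

    @dataclass
    class FinetuningArguments:
        # Field names match the attributes read in __init__ above;
        # defaults are placeholders, not this repo's real values.
        dpo_beta: float = 0.1
        dpo_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
        dpo_label_smoothing: float = 0.0
        dpo_ftx: float = 0.0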

View File

@@ -45,13 +45,10 @@ def run_dpo(
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
-        beta=finetuning_args.dpo_beta,
-        loss_type=finetuning_args.dpo_loss,
-        ftx_gamma=finetuning_args.dpo_ftx,
+        finetuning_args=finetuning_args,
         model=model,
         ref_model=ref_model,
         args=training_args,
-        finetuning_args=finetuning_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
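
Folding the hyperparameters into finetuning_args means adding a new DPO option no longer touches this call site, only the arguments dataclass and the trainer's __init__. A usage sketch under that assumption — model, ref_model, training_args, tokenizer, data_collator, and callbacks are placeholders standing in for the objects built earlier in run_dpo:

    # Hypothetical wiring; assumes the FinetuningArguments sketch above.
    finetuning_args = FinetuningArguments(dpo_beta=0.1, dpo_loss="sigmoid")
    trainer = CustomDPOTrainer(
        finetuning_args=finetuning_args,  # replaces beta=/loss_type=/ftx_gamma=
        model=model,
        ref_model=ref_model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
    )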

View File

@@ -5,7 +5,6 @@ from transformers import Trainer
 from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names
-from transformers.utils.versions import require_version
 from ..extras.logging import get_logger
 from ..extras.packages import is_galore_available
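
With require_version gone, is_galore_available remains as the guard for the optional GaLore dependency. Helpers like this are commonly a thin find_spec check; a minimal sketch of that pattern, offered as an assumption rather than this repo's exact implementation:

    import importlib.util

    def is_galore_available() -> bool:
        # True when the optional galore_torch package can be imported.
        return importlib.util.find_spec("galore_torch") is not None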