Former-commit-id: 616917bb3be7f71073b56ad8c7bc4e164b08b9b5
hiyouga
2024-03-26 17:26:14 +08:00
parent 04423b916f
commit 3336422760
7 changed files with 36 additions and 31 deletions

@@ -102,6 +102,10 @@ class RLHFArguments:
         default="sigmoid",
         metadata={"help": "The type of DPO loss to use."},
     )
+    dpo_label_smoothing: float = field(
+        default=0.0,
+        metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
+    )
     dpo_ftx: float = field(
         default=0.0,
         metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
@@ -248,6 +252,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
         if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
             raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
+        if self.stage == "dpo" and self.dpo_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
+            raise ValueError("`dpo_label_smoothing` is only valid for the sigmoid loss function.")
         if self.use_llama_pro and self.finetuning_type == "full":
             raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")