support ORPO

Former-commit-id: f44a4c27e2461cdaa1b16865f597a31033c0e6d9
This commit is contained in:
hiyouga
2024-03-31 18:29:50 +08:00
parent 526111a303
commit d764cd8736
22 changed files with 395 additions and 47 deletions

View File

@@ -110,6 +110,10 @@ class RLHFArguments:
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
)
orpo_beta: float = field(
default=0.1,
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
)
ppo_buffer_size: int = field(
default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
@@ -209,7 +213,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)